#![allow(missing_docs)]
use std::fmt;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::crypto::pqc::types::*;
/// Sizing policy applied to each per-algorithm buffer pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Number of buffers pre-allocated when a pool is created.
    pub initial_size: usize,
    /// Capacity limit consulted before a new buffer is created on a pool miss.
    pub max_size: usize,
    /// NOTE(review): not read anywhere in this file; pools currently create one
    /// buffer per miss.
    pub growth_increment: usize,
    /// NOTE(review): not read anywhere in this file; acquiring never waits and
    /// fails immediately when the pool is at capacity.
    pub acquire_timeout: Duration,
}

impl Default for PoolConfig {
    /// Four pre-allocated buffers, a cap of 100, a growth step of 4 and a
    /// five-second acquire timeout.
    fn default() -> Self {
        let acquire_timeout = Duration::from_millis(5_000);
        Self {
            acquire_timeout,
            growth_increment: 4,
            initial_size: 4,
            max_size: 100,
        }
    }
}
/// Atomic usage counters updated by the buffer pools.
#[derive(Debug, Default)]
pub struct PoolStats {
    /// Buffer acquisitions recorded by the pools.
    pub allocations: AtomicU64,
    /// Buffers handed back to a pool when their guard was dropped.
    pub deallocations: AtomicU64,
    /// Acquisitions served from an idle, already-allocated buffer.
    pub hits: AtomicU64,
    /// Acquisitions that found no idle buffer (including ones rejected at capacity).
    pub misses: AtomicU64,
    /// NOTE(review): one counter shared by every pool holding this `Arc`; each
    /// pool's constructor overwrites it with its `initial_size` and it grows by
    /// one for each newly created buffer.
    pub current_size: AtomicUsize,
}

impl PoolStats {
    /// Percentage (0.0–100.0) of acquisitions that were hits; 0.0 when nothing
    /// has been recorded yet.
    ///
    /// Hits and misses are loaded separately, so under concurrent use the result
    /// is an approximate snapshot.
    pub fn hit_rate(&self) -> f64 {
        let hit_count = self.hits.load(Ordering::Relaxed) as f64;
        let miss_count = self.misses.load(Ordering::Relaxed) as f64;
        let lookups = hit_count + miss_count;
        if lookups <= 0.0 {
            return 0.0;
        }
        hit_count / lookups * 100.0
    }
}
/// Heap buffer sized for an ML-KEM-768 public key.
#[derive(Clone)]
pub struct MlKemPublicKeyBuffer(pub Box<[u8; ML_KEM_768_PUBLIC_KEY_SIZE]>);
/// Heap buffer sized for an ML-KEM-768 secret key.
/// Its contents are zeroed when it is returned to a pool and when it is dropped.
#[derive(Clone)]
pub struct MlKemSecretKeyBuffer(pub Box<[u8; ML_KEM_768_SECRET_KEY_SIZE]>);
/// Heap buffer sized for an ML-KEM-768 ciphertext.
#[derive(Clone)]
pub struct MlKemCiphertextBuffer(pub Box<[u8; ML_KEM_768_CIPHERTEXT_SIZE]>);
/// Heap buffer sized for an ML-DSA-65 public key.
#[derive(Clone)]
pub struct MlDsaPublicKeyBuffer(pub Box<[u8; ML_DSA_65_PUBLIC_KEY_SIZE]>);
/// Heap buffer sized for an ML-DSA-65 secret key.
/// Its contents are zeroed when it is returned to a pool and when it is dropped.
#[derive(Clone)]
pub struct MlDsaSecretKeyBuffer(pub Box<[u8; ML_DSA_65_SECRET_KEY_SIZE]>);
/// Heap buffer sized for an ML-DSA-65 signature.
#[derive(Clone)]
pub struct MlDsaSignatureBuffer(pub Box<[u8; ML_DSA_65_SIGNATURE_SIZE]>);
/// Hook run on a pooled buffer just before it is pushed back onto its pool.
pub trait BufferCleanup {
/// Prepares the buffer for the next holder; implementations for secret
/// material must wipe it here.
fn cleanup(&mut self);
}
impl BufferCleanup for MlKemPublicKeyBuffer {
fn cleanup(&mut self) {}
}
impl BufferCleanup for MlKemCiphertextBuffer {
fn cleanup(&mut self) {}
}
impl BufferCleanup for MlDsaPublicKeyBuffer {
fn cleanup(&mut self) {}
}
impl BufferCleanup for MlDsaSignatureBuffer {
fn cleanup(&mut self) {}
}
impl BufferCleanup for MlKemSecretKeyBuffer {
    /// Zeroes every key byte so the next holder of this pooled buffer never
    /// observes a previous secret.
    fn cleanup(&mut self) {
        for byte in self.0.iter_mut() {
            *byte = 0;
        }
    }
}

impl BufferCleanup for MlDsaSecretKeyBuffer {
    /// Zeroes every key byte so the next holder of this pooled buffer never
    /// observes a previous secret.
    fn cleanup(&mut self) {
        for byte in self.0.iter_mut() {
            *byte = 0;
        }
    }
}
struct ObjectPool<T: BufferCleanup> {
available: Arc<Mutex<Vec<T>>>,
config: PoolConfig,
stats: Arc<PoolStats>,
factory: Box<dyn Fn() -> T + Send + Sync>,
}
impl<T: BufferCleanup> ObjectPool<T> {
fn new<F>(config: PoolConfig, stats: Arc<PoolStats>, factory: F) -> Self
where
F: Fn() -> T + Send + Sync + 'static,
{
let mut available = Vec::with_capacity(config.initial_size);
for _ in 0..config.initial_size {
available.push(factory());
}
stats
.current_size
.store(config.initial_size, Ordering::Relaxed);
Self {
available: Arc::new(Mutex::new(available)),
config,
stats,
factory: Box::new(factory),
}
}
fn acquire(&self) -> Result<PoolGuard<T>, PqcError> {
let mut available = self
.available
.lock()
.map_err(|_| PqcError::PoolError("Failed to lock pool".to_string()))?;
self.stats.allocations.fetch_add(1, Ordering::Relaxed);
let object = match available.pop() {
Some(obj) => {
self.stats.hits.fetch_add(1, Ordering::Relaxed);
obj
}
_ => {
self.stats.misses.fetch_add(1, Ordering::Relaxed);
let current_size = self.stats.current_size.load(Ordering::Relaxed);
if current_size >= self.config.max_size {
return Err(PqcError::PoolError("Pool at maximum capacity".to_string()));
}
self.stats.current_size.fetch_add(1, Ordering::Relaxed);
(self.factory)()
}
};
Ok(PoolGuard {
object: Some(object),
pool: self.available.clone(),
stats: self.stats.clone(),
})
}
fn available_count(&self) -> usize {
self.available.lock().map(|guard| guard.len()).unwrap_or(0)
}
}
/// Handle to a pooled buffer; when dropped it runs the buffer's cleanup and
/// pushes it back onto the pool it came from.
pub struct PoolGuard<T: BufferCleanup> {
/// `Some` from construction until `Drop` takes the buffer out.
object: Option<T>,
/// Free list the buffer is pushed back onto on drop.
pool: Arc<Mutex<Vec<T>>>,
/// Shared statistics; `deallocations` is incremented on drop.
stats: Arc<PoolStats>,
}
impl<T: BufferCleanup> PoolGuard<T> {
    /// Shared access to the pooled buffer.
    ///
    /// # Panics
    /// Only if the buffer was already taken, which happens solely inside `Drop`.
    // `object` stays `Some` for the guard's whole life, so this unwrap cannot fire.
    #[allow(clippy::unwrap_used)]
    pub fn as_ref(&self) -> &T {
        Option::as_ref(&self.object).unwrap()
    }

    /// Exclusive access to the pooled buffer.
    ///
    /// # Panics
    /// Only if the buffer was already taken, which happens solely inside `Drop`.
    // `object` stays `Some` for the guard's whole life, so this unwrap cannot fire.
    #[allow(clippy::unwrap_used)]
    pub fn as_mut(&mut self) -> &mut T {
        Option::as_mut(&mut self.object).unwrap()
    }
}
impl<T: BufferCleanup> Drop for PoolGuard<T> {
    /// Cleans the buffer, records the return, and pushes it back onto the pool.
    ///
    /// If the pool lock is poisoned the buffer is simply dropped instead of
    /// being returned.
    fn drop(&mut self) {
        let mut object = match self.object.take() {
            Some(inner) => inner,
            None => return,
        };
        object.cleanup();
        self.stats.deallocations.fetch_add(1, Ordering::Relaxed);
        match self.pool.lock() {
            Ok(mut free_list) => free_list.push(object),
            // Poisoned lock: `object` is dropped at the end of this function.
            Err(_) => {}
        }
    }
}
impl Drop for MlKemSecretKeyBuffer {
fn drop(&mut self) {
self.0.as_mut().fill(0);
}
}
impl Drop for MlDsaSecretKeyBuffer {
fn drop(&mut self) {
self.0.as_mut().fill(0);
}
}
/// Per-algorithm buffer pools for ML-KEM-768 and ML-DSA-65 material.
/// All six pools report into one shared `PoolStats`.
pub struct PqcMemoryPool {
ml_kem_public_keys: ObjectPool<MlKemPublicKeyBuffer>,
ml_kem_secret_keys: ObjectPool<MlKemSecretKeyBuffer>,
ml_kem_ciphertexts: ObjectPool<MlKemCiphertextBuffer>,
ml_dsa_public_keys: ObjectPool<MlDsaPublicKeyBuffer>,
ml_dsa_secret_keys: ObjectPool<MlDsaSecretKeyBuffer>,
ml_dsa_signatures: ObjectPool<MlDsaSignatureBuffer>,
// Same `Arc` that was handed to every pool above in `PqcMemoryPool::new`.
stats: Arc<PoolStats>,
}
impl PqcMemoryPool {
    /// Builds all six per-algorithm pools from one `config`.
    ///
    /// Each pool is pre-filled with zeroed buffers, and all of them report into
    /// a single `PoolStats`.
    pub fn new(config: PoolConfig) -> Self {
        let stats = Arc::new(PoolStats::default());
        let ml_kem_public_keys = ObjectPool::new(config.clone(), Arc::clone(&stats), || {
            MlKemPublicKeyBuffer(Box::new([0u8; ML_KEM_768_PUBLIC_KEY_SIZE]))
        });
        let ml_kem_secret_keys = ObjectPool::new(config.clone(), Arc::clone(&stats), || {
            MlKemSecretKeyBuffer(Box::new([0u8; ML_KEM_768_SECRET_KEY_SIZE]))
        });
        let ml_kem_ciphertexts = ObjectPool::new(config.clone(), Arc::clone(&stats), || {
            MlKemCiphertextBuffer(Box::new([0u8; ML_KEM_768_CIPHERTEXT_SIZE]))
        });
        let ml_dsa_public_keys = ObjectPool::new(config.clone(), Arc::clone(&stats), || {
            MlDsaPublicKeyBuffer(Box::new([0u8; ML_DSA_65_PUBLIC_KEY_SIZE]))
        });
        let ml_dsa_secret_keys = ObjectPool::new(config.clone(), Arc::clone(&stats), || {
            MlDsaSecretKeyBuffer(Box::new([0u8; ML_DSA_65_SECRET_KEY_SIZE]))
        });
        let ml_dsa_signatures = ObjectPool::new(config, Arc::clone(&stats), || {
            MlDsaSignatureBuffer(Box::new([0u8; ML_DSA_65_SIGNATURE_SIZE]))
        });
        Self {
            ml_kem_public_keys,
            ml_kem_secret_keys,
            ml_kem_ciphertexts,
            ml_dsa_public_keys,
            ml_dsa_secret_keys,
            ml_dsa_signatures,
            stats,
        }
    }

    /// Checks out an ML-KEM-768 public-key buffer.
    ///
    /// # Errors
    /// Propagates the pool's `PqcError` (poisoned lock or capacity reached).
    pub fn acquire_ml_kem_public_key(&self) -> Result<PoolGuard<MlKemPublicKeyBuffer>, PqcError> {
        self.ml_kem_public_keys.acquire()
    }

    /// Checks out an ML-KEM-768 secret-key buffer (wiped when returned).
    ///
    /// # Errors
    /// Propagates the pool's `PqcError` (poisoned lock or capacity reached).
    pub fn acquire_ml_kem_secret_key(&self) -> Result<PoolGuard<MlKemSecretKeyBuffer>, PqcError> {
        self.ml_kem_secret_keys.acquire()
    }

    /// Checks out an ML-KEM-768 ciphertext buffer.
    ///
    /// # Errors
    /// Propagates the pool's `PqcError` (poisoned lock or capacity reached).
    pub fn acquire_ml_kem_ciphertext(&self) -> Result<PoolGuard<MlKemCiphertextBuffer>, PqcError> {
        self.ml_kem_ciphertexts.acquire()
    }

    /// Checks out an ML-DSA-65 public-key buffer.
    ///
    /// # Errors
    /// Propagates the pool's `PqcError` (poisoned lock or capacity reached).
    pub fn acquire_ml_dsa_public_key(&self) -> Result<PoolGuard<MlDsaPublicKeyBuffer>, PqcError> {
        self.ml_dsa_public_keys.acquire()
    }

    /// Checks out an ML-DSA-65 secret-key buffer (wiped when returned).
    ///
    /// # Errors
    /// Propagates the pool's `PqcError` (poisoned lock or capacity reached).
    pub fn acquire_ml_dsa_secret_key(&self) -> Result<PoolGuard<MlDsaSecretKeyBuffer>, PqcError> {
        self.ml_dsa_secret_keys.acquire()
    }

    /// Checks out an ML-DSA-65 signature buffer.
    ///
    /// # Errors
    /// Propagates the pool's `PqcError` (poisoned lock or capacity reached).
    pub fn acquire_ml_dsa_signature(&self) -> Result<PoolGuard<MlDsaSignatureBuffer>, PqcError> {
        self.ml_dsa_signatures.acquire()
    }

    /// Statistics shared by all six pools.
    pub fn stats(&self) -> &PoolStats {
        &self.stats
    }

    /// Test helper: idle count of the ML-KEM public-key pool only.
    #[cfg(test)]
    pub fn available_count(&self) -> usize {
        self.ml_kem_public_keys.available_count()
    }
}
impl fmt::Debug for PqcMemoryPool {
    /// Shows the idle-buffer count of each pool plus the overall hit rate.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("PqcMemoryPool");
        out.field("ml_kem_public_keys", &self.ml_kem_public_keys.available_count());
        out.field("ml_kem_secret_keys", &self.ml_kem_secret_keys.available_count());
        out.field("ml_kem_ciphertexts", &self.ml_kem_ciphertexts.available_count());
        out.field("ml_dsa_public_keys", &self.ml_dsa_public_keys.available_count());
        out.field("ml_dsa_secret_keys", &self.ml_dsa_secret_keys.available_count());
        out.field("ml_dsa_signatures", &self.ml_dsa_signatures.available_count());
        out.field("hit_rate", &format!("{:.1}%", self.stats.hit_rate()));
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // Dropping a guard and re-acquiring must hand back the very same allocation.
    #[test]
    fn test_pool_reuses_objects() {
        let pool = PqcMemoryPool::new(PoolConfig::default());
        let guard1 = pool.acquire_ml_kem_public_key().unwrap();
        let ptr1 = guard1.as_ref().0.as_ptr();
        drop(guard1);
        let guard2 = pool.acquire_ml_kem_public_key().unwrap();
        let ptr2 = guard2.as_ref().0.as_ptr();
        assert_eq!(ptr1, ptr2, "Pool should reuse the same buffer");
    }

    // Ten simultaneous holders force the secret-key pool to grow from 2 to 10.
    #[tokio::test]
    async fn test_concurrent_pool_access() {
        let pool = Arc::new(PqcMemoryPool::new(PoolConfig {
            initial_size: 2,
            max_size: 10,
            growth_increment: 1,
            acquire_timeout: Duration::from_secs(1),
        }));
        let mut handles = vec![];
        for _ in 0..10 {
            let pool_clone = pool.clone();
            handles.push(tokio::spawn(async move {
                let _guard = pool_clone.acquire_ml_kem_secret_key().unwrap();
                tokio::time::sleep(Duration::from_millis(10)).await;
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }
        let current_size = pool.stats().current_size.load(Ordering::Relaxed);
        assert_eq!(current_size, 10, "Pool should have grown to 10 objects");
    }

    // A guard removes one idle buffer while alive and restores it on drop.
    #[test]
    fn test_guard_auto_returns_on_drop() {
        let pool = PqcMemoryPool::new(PoolConfig::default());
        let initial_available = pool.available_count();
        {
            let _guard = pool.acquire_ml_kem_ciphertext().unwrap();
            assert_eq!(
                pool.ml_kem_ciphertexts.available_count(),
                initial_available - 1
            );
        }
        assert_eq!(pool.ml_kem_ciphertexts.available_count(), initial_available);
    }

    // Once `max_size` buffers are checked out, further acquisitions fail.
    #[test]
    fn test_pool_respects_max_size() {
        let pool = PqcMemoryPool::new(PoolConfig {
            initial_size: 1,
            max_size: 2,
            growth_increment: 1,
            acquire_timeout: Duration::from_secs(1),
        });
        let _guard1 = pool.acquire_ml_dsa_signature().unwrap();
        let _guard2 = pool.acquire_ml_dsa_signature().unwrap();
        let result = pool.acquire_ml_dsa_signature();
        assert!(result.is_err());
        assert!(matches!(result, Err(PqcError::PoolError(_))));
    }

    // Hits, misses and deallocations are counted as buffers move in and out.
    #[test]
    fn test_pool_statistics() {
        let pool = PqcMemoryPool::new(PoolConfig {
            initial_size: 2,
            max_size: 10,
            growth_increment: 1,
            acquire_timeout: Duration::from_secs(1),
        });
        let guard1 = pool.acquire_ml_kem_public_key().unwrap();
        let guard2 = pool.acquire_ml_kem_public_key().unwrap();
        assert_eq!(pool.stats().hits.load(Ordering::Relaxed), 2);
        assert_eq!(pool.stats().misses.load(Ordering::Relaxed), 0);
        let _guard3 = pool.acquire_ml_kem_public_key().unwrap();
        assert_eq!(pool.stats().hits.load(Ordering::Relaxed), 2);
        assert_eq!(pool.stats().misses.load(Ordering::Relaxed), 1);
        drop(guard1);
        drop(guard2);
        assert_eq!(pool.stats().deallocations.load(Ordering::Relaxed), 2);
    }

    // Secret buffers of BOTH algorithms must be wiped before the same
    // allocation is handed out again.
    #[test]
    fn test_secret_key_zeroization() {
        let pool = PqcMemoryPool::new(PoolConfig::default());
        let kem_ptr = {
            let mut guard = pool.acquire_ml_kem_secret_key().unwrap();
            guard.as_mut().0.fill(0xFF);
            guard.as_ref().0.as_ptr()
        };
        let dsa_ptr = {
            let mut guard = pool.acquire_ml_dsa_secret_key().unwrap();
            guard.as_mut().0.fill(0xFF);
            guard.as_ref().0.as_ptr()
        };
        let kem_guard = pool.acquire_ml_kem_secret_key().unwrap();
        assert_eq!(
            kem_guard.as_ref().0.as_ptr(),
            kem_ptr,
            "ML-KEM secret key buffer should be reused"
        );
        assert!(
            kem_guard.as_ref().0.iter().all(|&b| b == 0),
            "Secret key buffer should be zeroed"
        );
        let dsa_guard = pool.acquire_ml_dsa_secret_key().unwrap();
        assert_eq!(
            dsa_guard.as_ref().0.as_ptr(),
            dsa_ptr,
            "ML-DSA secret key buffer should be reused"
        );
        assert!(
            dsa_guard.as_ref().0.iter().all(|&b| b == 0),
            "ML-DSA secret key buffer should be zeroed"
        );
    }

    // Every buffer type has exactly the size of its algorithm constant.
    #[test]
    fn test_all_buffer_types() {
        let pool = PqcMemoryPool::new(PoolConfig::default());
        let ml_kem_pk = pool.acquire_ml_kem_public_key().unwrap();
        assert_eq!(ml_kem_pk.as_ref().0.len(), ML_KEM_768_PUBLIC_KEY_SIZE);
        let ml_kem_sk = pool.acquire_ml_kem_secret_key().unwrap();
        assert_eq!(ml_kem_sk.as_ref().0.len(), ML_KEM_768_SECRET_KEY_SIZE);
        let ml_kem_ct = pool.acquire_ml_kem_ciphertext().unwrap();
        assert_eq!(ml_kem_ct.as_ref().0.len(), ML_KEM_768_CIPHERTEXT_SIZE);
        let ml_dsa_pk = pool.acquire_ml_dsa_public_key().unwrap();
        assert_eq!(ml_dsa_pk.as_ref().0.len(), ML_DSA_65_PUBLIC_KEY_SIZE);
        let ml_dsa_sk = pool.acquire_ml_dsa_secret_key().unwrap();
        assert_eq!(ml_dsa_sk.as_ref().0.len(), ML_DSA_65_SECRET_KEY_SIZE);
        let ml_dsa_sig = pool.acquire_ml_dsa_signature().unwrap();
        assert_eq!(ml_dsa_sig.as_ref().0.len(), ML_DSA_65_SIGNATURE_SIZE);
    }

    // Two hits and one miss yield a ~66.7% hit rate.
    #[test]
    fn test_hit_rate_calculation() {
        let pool = PqcMemoryPool::new(PoolConfig {
            initial_size: 2,
            max_size: 10,
            growth_increment: 1,
            acquire_timeout: Duration::from_secs(1),
        });
        let _g1 = pool.acquire_ml_kem_public_key().unwrap();
        let _g2 = pool.acquire_ml_kem_public_key().unwrap();
        let _g3 = pool.acquire_ml_kem_public_key().unwrap();
        let hit_rate = pool.stats().hit_rate();
        assert!(
            (hit_rate - 66.7).abs() < 0.1,
            "Hit rate should be approximately 66.7%"
        );
    }
}