use std::sync::Arc;
use aes_gcm::{Aes256Gcm, Key, KeyInit};
use aes_gcm::aead::{Aead, AeadCore, OsRng, Nonce};
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::info;
use serde::{Deserialize, Serialize};
/// AES-256-GCM encryption engine that keeps running timing metrics.
///
/// `Clone` is shallow for the shared state: both ciphers and the metrics sit behind
/// `Arc`s, so every clone encrypts with the same key and updates the same counters.
#[derive(Clone)]
pub struct HardwareEncryptionEngine {
/// Cipher used by the software encryption/decryption path.
// NOTE(review): the cipher locks are only ever read-locked in this file; a plain
// `Arc<Aes256Gcm>` would likely suffice unless key rotation is planned.
cipher: Arc<RwLock<Aes256Gcm>>,
/// Cipher used by the "hardware accelerated" path. `new` builds it as a clone of
/// `cipher`, so both hold the same key and can open each other's output.
ring_cipher: Arc<RwLock<Aes256Gcm>>,
/// Running counters and averages, updated after each successful encrypt/decrypt.
metrics: Arc<RwLock<EncryptionMetrics>>,
/// Engine settings supplied at construction.
config: EncryptionConfig,
/// CPU features detected once at construction.
hardware_caps: HardwareCapabilities,
}
/// Tuning knobs for [`HardwareEncryptionEngine`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly (e.g. in tests or
/// when deciding whether an engine must be rebuilt).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncryptionConfig {
    /// Intended iteration count for key derivation.
    /// NOTE(review): `generate_master_key` currently ignores the config and draws a random
    /// key, so this value is not consulted.
    pub key_derivation_iterations: u32,
    /// When `true` and the CPU reports AES-NI, the engine takes its hardware-accelerated path.
    pub enable_hardware_acceleration: bool,
    /// Intended batch size.
    /// NOTE(review): `encrypt_batch` does not read this; it spawns one task per item.
    pub batch_size: usize,
    /// Intended switch for key caching.
    /// NOTE(review): not read anywhere alongside this type — confirm intended use.
    pub cache_keys: bool,
    /// When `true`, `encrypt_batch` encrypts items concurrently on spawned Tokio tasks;
    /// otherwise it encrypts them sequentially.
    pub enable_parallel: bool,
}
impl Default for EncryptionConfig {
    /// Production defaults: 100,000 key-derivation iterations, batches of 64, and every
    /// feature flag (hardware acceleration, key caching, parallel batches) switched on.
    fn default() -> Self {
        const DEFAULT_KDF_ITERATIONS: u32 = 100_000;
        const DEFAULT_BATCH_SIZE: usize = 64;
        let enabled = true;
        Self {
            enable_parallel: enabled,
            cache_keys: enabled,
            batch_size: DEFAULT_BATCH_SIZE,
            enable_hardware_acceleration: enabled,
            key_derivation_iterations: DEFAULT_KDF_ITERATIONS,
        }
    }
}
/// CPU instruction-set extensions detected on the running machine.
///
/// A plain set of flags, so it is `Copy`; `Default` is the "no extensions" value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct HardwareCapabilities {
    /// AES-NI (x86 `aes` feature); gates the engine's hardware-accelerated path.
    pub aes_ni: bool,
    /// AVX2 support (reported only; not otherwise consulted by the engine).
    pub avx2: bool,
    /// AVX-512 Foundation (`avx512f`) support (reported only).
    pub avx512: bool,
    /// BMI2 support (reported only).
    pub bmi2: bool,
    /// FMA support (reported only).
    pub fma: bool,
}
/// Snapshot of the running statistics kept by [`HardwareEncryptionEngine`].
///
/// NOTE(review): no code alongside this type writes `throughput_mbps` or `cache_hit_rate`,
/// so they stay at their default of `0.0`.
#[derive(Default, Clone, Debug)]
pub struct EncryptionMetrics {
    /// Number of successful encryptions recorded.
    pub total_encryptions: u64,
    /// Number of successful decryptions recorded.
    pub total_decryptions: u64,
    /// Running mean encryption latency, in microseconds.
    pub avg_encryption_time_us: f64,
    /// Running mean decryption latency, in microseconds.
    pub avg_decryption_time_us: f64,
    /// Throughput in MB/s (currently never populated).
    pub throughput_mbps: f64,
    /// Running fraction (0.0..=1.0) of encryptions reported as hardware accelerated.
    pub hardware_acceleration_rate: f64,
    /// Key-cache hit rate (currently never populated).
    pub cache_hit_rate: f64,
}
/// Output of [`HardwareEncryptionEngine::encrypt`]: everything except the key that is
/// needed to decrypt. The key is not included, so only the engine that produced this value
/// (or a clone of it) can decrypt it.
///
/// Keep the field order stable: serde emits fields in declaration order, which positional
/// formats depend on.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct EncryptedData {
    /// AES-GCM ciphertext with the authentication tag stripped off the end.
    pub ciphertext: Vec<u8>,
    /// GCM authentication tag split from the end of the AEAD output (16 bytes from `encrypt`).
    pub tag: Vec<u8>,
    /// Nonce generated freshly for this message by `encrypt` (96 bits for `Aes256Gcm`).
    pub nonce: Vec<u8>,
    /// Algorithm label; `encrypt` always writes `"AES-256-GCM"`.
    pub algorithm: String,
    /// Seconds since the UNIX epoch at encryption time.
    pub timestamp: u64,
    /// Whether the engine's hardware-accelerated path produced this value; `decrypt` uses it
    /// to pick the matching path.
    pub hardware_accelerated: bool,
}
/// Bookkeeping record for a cached key: the key bytes plus age and usage statistics.
// `dead_code` is allowed because nothing in this module constructs or reads this type yet.
// NOTE(review): presumably reserved for `EncryptionConfig::cache_keys` — confirm or remove.
#[allow(dead_code)]
#[derive(Clone, Debug)]
struct KeyCacheEntry {
    /// Raw key bytes.
    key: Vec<u8>,
    /// When the entry was inserted.
    created_at: Instant,
    /// When the entry was last handed out.
    last_used: Instant,
    /// How many times the entry has been used.
    usage_count: u64,
}
impl HardwareEncryptionEngine {
pub fn new(config: EncryptionConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let hardware_caps = Self::detect_hardware_capabilities();
info!("Hardware capabilities detected: {:?}", hardware_caps);
let master_key = Self::generate_master_key(&config)?;
let aes_key = &Key::<Aes256Gcm>::from_slice(&master_key);
let cipher = Aes256Gcm::new(aes_key);
let ring_cipher = cipher.clone();
Ok(Self {
cipher: Arc::new(RwLock::new(cipher)),
ring_cipher: Arc::new(RwLock::new(ring_cipher)),
metrics: Arc::new(RwLock::new(EncryptionMetrics::default())),
config,
hardware_caps,
})
}
pub async fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, Box<dyn std::error::Error + Send + Sync>> {
let start_time = Instant::now();
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let (ciphertext, tag, hardware_accelerated) = if self.config.enable_hardware_acceleration && self.hardware_caps.aes_ni {
self.encrypt_hardware_accelerated(plaintext, &nonce).await?
} else {
self.encrypt_software(plaintext, &nonce).await?
};
let encryption_time = start_time.elapsed();
{
let mut metrics = self.metrics.write().await;
metrics.total_encryptions += 1;
metrics.avg_encryption_time_us = (metrics.avg_encryption_time_us * (metrics.total_encryptions - 1) as f64 + encryption_time.as_micros() as f64) / metrics.total_encryptions as f64;
if hardware_accelerated {
metrics.hardware_acceleration_rate = (metrics.hardware_acceleration_rate * (metrics.total_encryptions - 1) as f64 + 1.0) / metrics.total_encryptions as f64;
}
}
Ok(EncryptedData {
ciphertext,
tag,
nonce: nonce.to_vec(),
algorithm: "AES-256-GCM".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
hardware_accelerated,
})
}
pub async fn decrypt(&self, encrypted_data: &EncryptedData) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
let start_time = Instant::now();
let nonce = &Nonce::<Aes256Gcm>::from_slice(&encrypted_data.nonce);
let plaintext = if encrypted_data.hardware_accelerated && self.config.enable_hardware_acceleration && self.hardware_caps.aes_ni {
self.decrypt_hardware_accelerated(&encrypted_data.ciphertext, &encrypted_data.tag, nonce).await?
} else {
self.decrypt_software(&encrypted_data.ciphertext, &encrypted_data.tag, nonce).await?
};
let decryption_time = start_time.elapsed();
{
let mut metrics = self.metrics.write().await;
metrics.total_decryptions += 1;
metrics.avg_decryption_time_us = (metrics.avg_decryption_time_us * (metrics.total_decryptions - 1) as f64 + decryption_time.as_micros() as f64) / metrics.total_decryptions as f64;
}
Ok(plaintext)
}
pub async fn encrypt_batch(&self, plaintexts: &[Vec<u8>]) -> Result<Vec<EncryptedData>, Box<dyn std::error::Error + Send + Sync>> {
if !self.config.enable_parallel {
let mut results = Vec::with_capacity(plaintexts.len());
for plaintext in plaintexts {
results.push(self.encrypt(plaintext).await?);
}
return Ok(results);
}
let mut handles = Vec::new();
for plaintext in plaintexts {
let engine = self.clone();
let plaintext = plaintext.clone();
let handle = tokio::spawn(async move {
engine.encrypt(&plaintext).await
});
handles.push(handle);
}
let mut results = Vec::with_capacity(handles.len());
for handle in handles {
match handle.await {
Ok(result) => results.push(result?),
Err(e) => return Err(Box::new(e)),
}
}
Ok(results)
}
pub async fn get_metrics(&self) -> EncryptionMetrics {
self.metrics.read().await.clone()
}
pub async fn reset_metrics(&self) {
let mut metrics = self.metrics.write().await;
*metrics = EncryptionMetrics::default();
}
fn detect_hardware_capabilities() -> HardwareCapabilities {
HardwareCapabilities {
aes_ni: is_x86_feature_detected!("aes"),
avx2: is_x86_feature_detected!("avx2"),
avx512: is_x86_feature_detected!("avx512f"),
bmi2: is_x86_feature_detected!("bmi2"),
fma: is_x86_feature_detected!("fma"),
}
}
fn generate_master_key(_config: &EncryptionConfig) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
use rand::RngCore;
let mut rng = rand::thread_rng();
let mut key = vec![0u8; 32]; rng.fill_bytes(&mut key);
Ok(key)
}
async fn encrypt_hardware_accelerated(&self, plaintext: &[u8], nonce: &Nonce<Aes256Gcm>) -> Result<(Vec<u8>, Vec<u8>, bool), Box<dyn std::error::Error + Send + Sync>> {
let cipher = self.ring_cipher.read().await;
let _aad: &[u8; 0] = &[];
let mut buffer = plaintext.to_vec();
buffer.resize(buffer.len() + 16, 0);
let nonce_bytes = &Nonce::<Aes256Gcm>::from_slice(nonce.as_slice());
let ciphertext = cipher.encrypt(nonce_bytes, plaintext)
.map_err(|e| format!("Hardware encryption failed: {}", e))?;
let tag_len = 16;
let tag = ciphertext[ciphertext.len() - tag_len..].to_vec();
let ciphertext = ciphertext[..ciphertext.len() - tag_len].to_vec();
Ok((ciphertext, tag, true))
}
async fn encrypt_software(&self, plaintext: &[u8], nonce: &Nonce<Aes256Gcm>) -> Result<(Vec<u8>, Vec<u8>, bool), Box<dyn std::error::Error + Send + Sync>> {
let cipher = self.cipher.read().await;
let ciphertext = cipher.encrypt(nonce, plaintext)
.map_err(|e| format!("Software encryption failed: {}", e))?;
let tag_start = ciphertext.len() - 16;
let ciphertext_bytes = ciphertext[..tag_start].to_vec();
let tag = ciphertext[tag_start..].to_vec();
Ok((ciphertext_bytes, tag, false))
}
async fn decrypt_hardware_accelerated(&self, ciphertext: &[u8], tag: &[u8], nonce: &Nonce<Aes256Gcm>) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
let cipher = self.ring_cipher.read().await;
let mut buffer = Vec::with_capacity(ciphertext.len() + tag.len());
buffer.extend_from_slice(ciphertext);
buffer.extend_from_slice(tag);
let _aad: &[u8; 0] = &[];
let nonce_bytes = &Nonce::<Aes256Gcm>::from_slice(nonce.as_slice());
let mut encrypted_data = Vec::with_capacity(ciphertext.len() + tag.len());
encrypted_data.extend_from_slice(ciphertext);
encrypted_data.extend_from_slice(tag);
let plaintext = cipher.decrypt(nonce_bytes, &*encrypted_data)
.map_err(|e| format!("Hardware decryption failed: {}", e))?;
Ok(plaintext)
}
async fn decrypt_software(&self, ciphertext: &[u8], tag: &[u8], nonce: &Nonce<Aes256Gcm>) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
let cipher = self.cipher.read().await;
let mut encrypted_data = Vec::with_capacity(ciphertext.len() + tag.len());
encrypted_data.extend_from_slice(ciphertext);
encrypted_data.extend_from_slice(tag);
let plaintext = cipher.decrypt(nonce, &*encrypted_data)
.map_err(|e| format!("Software decryption failed: {}", e))?;
Ok(plaintext)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an engine with the default configuration, panicking if construction fails.
    fn default_engine() -> HardwareEncryptionEngine {
        HardwareEncryptionEngine::new(EncryptionConfig::default()).unwrap()
    }

    #[tokio::test]
    async fn test_encryption_decryption() {
        let engine = default_engine();
        let message = b"Hello, Solana Account Cleaner!";
        let sealed = engine.encrypt(message).await.unwrap();
        let opened = engine.decrypt(&sealed).await.unwrap();
        assert_eq!(message.to_vec(), opened);
    }

    #[tokio::test]
    async fn test_batch_encryption() {
        let engine = default_engine();
        let inputs: Vec<Vec<u8>> = [b"Data 1", b"Data 2", b"Data 3"]
            .iter()
            .map(|bytes| bytes.to_vec())
            .collect();
        let sealed = engine.encrypt_batch(&inputs).await.unwrap();
        assert_eq!(sealed.len(), inputs.len());
        for (item, expected) in sealed.iter().zip(&inputs) {
            let opened = engine.decrypt(item).await.unwrap();
            assert_eq!(&opened, expected);
        }
    }

    #[test]
    fn test_hardware_capabilities_detection() {
        let detected = HardwareEncryptionEngine::detect_hardware_capabilities();
        println!("Hardware capabilities: {:?}", detected);
    }

    #[tokio::test]
    async fn test_metrics_tracking() {
        let engine = default_engine();
        let before = engine.get_metrics().await;
        assert_eq!(before.total_encryptions, 0);
        assert_eq!(before.total_decryptions, 0);
        engine.encrypt(b"Test data").await.unwrap();
        let after = engine.get_metrics().await;
        assert_eq!(after.total_encryptions, 1);
        assert!(after.avg_encryption_time_us > 0.0);
    }
}