use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Key, Nonce as AesNonce};
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use lazy_static::lazy_static;
use spin::Mutex;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::FsError;
use crate::arch;
/// AEAD algorithms understood by the inline encryption engine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EncryptAlgo {
    /// ChaCha20-Poly1305; always usable (see [`CpuFeatures::supports_algo`]).
    ChaCha20Poly1305,
    /// AES-256-GCM; usable only when AES-NI and PCLMULQDQ are both reported.
    Aes256Gcm,
}

/// The CPU capability flags consulted when choosing an encryption path.
#[derive(Clone, Debug)]
pub struct CpuFeatures {
    /// AES-NI instructions present.
    pub aes_ni: bool,
    /// PCLMULQDQ present; required together with `aes_ni` for AES-256-GCM.
    pub pclmulqdq: bool,
    /// AVX2 present; decides whether ChaCha20-Poly1305 counts as accelerated.
    pub avx2: bool,
    /// AVX-512 present; not consulted by any method in this impl.
    pub avx512: bool,
}

impl CpuFeatures {
    /// Returns the feature set the engine assumes.
    ///
    /// NOTE(review): this does not query the hardware. It always reports AES-NI,
    /// PCLMULQDQ and AVX2 as present and AVX-512 as absent.
    pub fn detect() -> Self {
        let (aes_ni, pclmulqdq, avx2, avx512) = (true, true, true, false);
        Self {
            aes_ni,
            pclmulqdq,
            avx2,
            avx512,
        }
    }

    /// Returns whether `algo` can be used at all with these features.
    ///
    /// ChaCha20-Poly1305 is always usable. AES-256-GCM needs both AES-NI and PCLMULQDQ.
    pub fn supports_algo(&self, algo: EncryptAlgo) -> bool {
        let aes_gcm_ready = self.aes_ni && self.pclmulqdq;
        match algo {
            EncryptAlgo::Aes256Gcm => aes_gcm_ready,
            EncryptAlgo::ChaCha20Poly1305 => true,
        }
    }

    /// Returns whether `algo` runs on a hardware-accelerated path.
    ///
    /// AES-256-GCM is accelerated when AES-NI and PCLMULQDQ are present.
    /// ChaCha20-Poly1305 is accelerated when AVX2 is present.
    pub fn is_accelerated(&self, algo: EncryptAlgo) -> bool {
        match algo {
            EncryptAlgo::Aes256Gcm => self.aes_ni && self.pclmulqdq,
            EncryptAlgo::ChaCha20Poly1305 => self.avx2,
        }
    }
}
/// A 256-bit symmetric key tagged with its algorithm and identifier.
///
/// The `key` bytes are zeroized on drop. `algo` and `key_id` are skipped by zeroization.
/// `Debug` is implemented by hand so that formatting a key (e.g. in logs or panic
/// messages) never prints the secret bytes.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct EncryptionKey {
    /// Raw secret key material.
    pub key: [u8; 32],
    /// Algorithm this key is intended for (not secret, excluded from zeroization).
    #[zeroize(skip)]
    pub algo: EncryptAlgo,
    /// Identifier the key is registered under (not secret, excluded from zeroization).
    #[zeroize(skip)]
    pub key_id: u64,
}

impl core::fmt::Debug for EncryptionKey {
    /// Formats the key with its secret bytes redacted.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("EncryptionKey")
            .field("key", &"<redacted>")
            .field("algo", &self.algo)
            .field("key_id", &self.key_id)
            .finish()
    }
}

impl EncryptionKey {
    /// Wraps existing key material.
    pub fn new(key: [u8; 32], algo: EncryptAlgo, key_id: u64) -> Self {
        Self { key, algo, key_id }
    }

    /// Creates a key with fresh material from `crate::crypto::random::generate_key`.
    ///
    /// # Errors
    /// Propagates the `RandomError` returned by `generate_key`.
    pub fn generate(
        algo: EncryptAlgo,
        key_id: u64,
    ) -> Result<Self, crate::crypto::random::RandomError> {
        let key = crate::crypto::random::generate_key()?;
        Ok(Self::new(key, algo, key_id))
    }
}
/// A 96-bit (12-byte) AEAD nonce.
///
/// Counter-derived nonces keep the counter little-endian in bytes `0..8` and leave
/// bytes `8..12` zero.
#[derive(Debug, Clone, Copy)]
pub struct Nonce {
    /// Raw nonce bytes.
    pub bytes: [u8; 12],
}

impl Nonce {
    /// Wraps raw nonce bytes unchanged.
    pub fn new(bytes: [u8; 12]) -> Self {
        Self { bytes }
    }

    /// Builds a nonce whose low 8 bytes hold `counter` (little-endian) and whose
    /// upper 4 bytes are zero.
    pub fn from_counter(counter: u64) -> Self {
        let mut bytes = [0u8; 12];
        bytes[..8].copy_from_slice(&counter.to_le_bytes());
        Self { bytes }
    }

    /// Advances the nonce by one, treating all 12 bytes as a little-endian counter.
    ///
    /// For a nonce built by [`Nonce::from_counter`] with `counter < u64::MAX`, the
    /// result equals `Nonce::from_counter(counter + 1)`. At `u64::MAX` the carry moves
    /// into byte 8. This avoids both a debug-build overflow panic and a release-build
    /// wrap back to an already issued nonce. The value wraps to all zeros only after
    /// all 2^96 states have been used.
    pub fn increment(&mut self) {
        for byte in self.bytes.iter_mut() {
            let (next, carry) = byte.overflowing_add(1);
            *byte = next;
            if !carry {
                break;
            }
        }
    }
}
/// Description of a submitted encryption operation.
///
/// NOTE(review): nothing in this module's visible code constructs or reads
/// `EncryptOp`; confirm field semantics against its users elsewhere in the crate.
#[derive(Debug, Clone)]
pub struct EncryptOp {
    pub op_id: u64,
    pub algo: EncryptAlgo,
    pub key_id: u64,
    pub input_size: u64,
    pub hw_accelerated: bool,
    /// Submission time; units not established here (TODO confirm).
    pub submitted: u64,
}

/// Outcome of an operation modelled by [`InlineEncryptManager`].
#[derive(Debug, Clone)]
pub struct EncryptResult {
    /// Identifier assigned to the operation.
    pub op_id: u64,
    /// Size in bytes of the produced output.
    pub output_size: u64,
    /// Nonce allocated for the operation.
    pub nonce: Nonce,
    /// Modelled execution time in nanoseconds.
    pub exec_time_ns: u64,
    /// Whether the hardware-accelerated path was modelled.
    pub hw_accelerated: bool,
}

impl EncryptResult {
    /// Clock frequency (GHz) assumed by [`EncryptResult::cycles_per_byte`].
    pub const ASSUMED_CPU_FREQ_GHZ: f64 = 3.0;

    /// Throughput for `input_size` bytes processed in `exec_time_ns`.
    ///
    /// The unit is gigabytes per second (10^9 bytes/s), i.e. bytes per nanosecond,
    /// despite the `gbps` name. Returns `0.0` when `exec_time_ns` is zero.
    pub fn throughput_gbps(&self, input_size: u64) -> f32 {
        if self.exec_time_ns == 0 {
            return 0.0;
        }
        (input_size as f64 / self.exec_time_ns as f64) as f32
    }

    /// CPU cycles per byte at [`EncryptResult::ASSUMED_CPU_FREQ_GHZ`].
    ///
    /// Returns `0.0` when `input_size` is zero.
    pub fn cycles_per_byte(&self, input_size: u64) -> f32 {
        self.cycles_per_byte_at(input_size, Self::ASSUMED_CPU_FREQ_GHZ)
    }

    /// CPU cycles per byte for a clock running at `cpu_freq_ghz` GHz.
    ///
    /// Returns `0.0` when `input_size` is zero.
    pub fn cycles_per_byte_at(&self, input_size: u64, cpu_freq_ghz: f64) -> f32 {
        if input_size == 0 {
            return 0.0;
        }
        // A clock of `cpu_freq_ghz` GHz ticks `cpu_freq_ghz` times per nanosecond.
        let cycles = self.exec_time_ns as f64 * cpu_freq_ghz;
        (cycles / input_size as f64) as f32
    }
}
/// Aggregate counters accumulated by [`InlineEncryptManager`].
#[derive(Debug, Clone, Default)]
pub struct EncryptStats {
    /// Number of recorded operations.
    pub total_ops: u64,
    /// Operations that took the hardware-accelerated path.
    pub hw_accel_ops: u64,
    /// Operations that took the software path.
    pub sw_ops: u64,
    /// Total input bytes across all operations.
    pub total_bytes: u64,
    /// Total modelled execution time in nanoseconds.
    pub total_time_ns: u64,
}

impl EncryptStats {
    /// Clock frequency (GHz) assumed by [`EncryptStats::avg_cycles_per_byte`].
    pub const ASSUMED_CPU_FREQ_GHZ: f64 = 3.0;

    /// Fraction of operations that were hardware-accelerated, or `0.0` if none were recorded.
    pub fn hw_accel_ratio(&self) -> f32 {
        if self.total_ops == 0 {
            return 0.0;
        }
        self.hw_accel_ops as f32 / self.total_ops as f32
    }

    /// Average throughput in gigabytes per second (bytes per nanosecond), despite the
    /// `gbps` name. Returns `0.0` if no time was recorded.
    pub fn avg_throughput_gbps(&self) -> f32 {
        if self.total_time_ns == 0 {
            return 0.0;
        }
        (self.total_bytes as f64 / self.total_time_ns as f64) as f32
    }

    /// Average cycles per byte at [`EncryptStats::ASSUMED_CPU_FREQ_GHZ`].
    ///
    /// Returns `0.0` if no bytes were recorded.
    pub fn avg_cycles_per_byte(&self) -> f32 {
        self.avg_cycles_per_byte_at(Self::ASSUMED_CPU_FREQ_GHZ)
    }

    /// Average cycles per byte for a clock running at `cpu_freq_ghz` GHz.
    ///
    /// Returns `0.0` if no bytes were recorded.
    pub fn avg_cycles_per_byte_at(&self, cpu_freq_ghz: f64) -> f32 {
        if self.total_bytes == 0 {
            return 0.0;
        }
        // A clock of `cpu_freq_ghz` GHz ticks `cpu_freq_ghz` times per nanosecond.
        let total_cycles = self.total_time_ns as f64 * cpu_freq_ghz;
        (total_cycles / self.total_bytes as f64) as f32
    }
}
/// State of the simulated inline-encryption engine: CPU features, registered keys,
/// nonce/op counters and running statistics.
pub struct InlineEncryptManager {
    /// Feature set captured by `CpuFeatures::detect` at construction.
    features: CpuFeatures,
    /// Registered keys, indexed by each key's `key_id`.
    keys: BTreeMap<u64, EncryptionKey>,
    /// Next counter value turned into a nonce by `encrypt`.
    nonce_counter: u64,
    /// Identifier handed to the next operation.
    next_op_id: u64,
    /// Running operation statistics.
    stats: EncryptStats,
}

impl Default for InlineEncryptManager {
    /// Same as [`InlineEncryptManager::new`].
    fn default() -> Self {
        InlineEncryptManager::new()
    }
}
impl InlineEncryptManager {
pub fn new() -> Self {
Self {
features: CpuFeatures::detect(),
keys: BTreeMap::new(),
nonce_counter: 1,
next_op_id: 1,
stats: EncryptStats::default(),
}
}
pub fn add_key(&mut self, key: EncryptionKey) -> Result<(), &'static str> {
if !self.features.supports_algo(key.algo) {
return Err("Algorithm not supported by CPU");
}
self.keys.insert(key.key_id, key);
Ok(())
}
pub fn remove_key(&mut self, key_id: u64) -> bool {
self.keys.remove(&key_id).is_some()
}
pub fn encrypt(
&mut self,
key_id: u64,
input_size: u64,
timestamp: u64,
) -> Result<EncryptResult, &'static str> {
let key = self.keys.get(&key_id).ok_or("Key not found")?;
let hw_accelerated = self.features.is_accelerated(key.algo);
let nonce = Nonce::from_counter(self.nonce_counter);
self.nonce_counter += 1;
let exec_time_ns = if hw_accelerated {
match key.algo {
EncryptAlgo::Aes256Gcm => (input_size as f64 * 0.63 / 3.0) as u64, EncryptAlgo::ChaCha20Poly1305 => (input_size as f64 * 1.2 / 3.0) as u64, }
} else {
(input_size as f64 * 7.0 / 3.0) as u64
};
let op_id = self.next_op_id;
self.next_op_id += 1;
self.stats.total_ops += 1;
if hw_accelerated {
self.stats.hw_accel_ops += 1;
} else {
self.stats.sw_ops += 1;
}
self.stats.total_bytes += input_size;
self.stats.total_time_ns += exec_time_ns;
Ok(EncryptResult {
op_id,
output_size: input_size + 16, nonce,
exec_time_ns,
hw_accelerated,
})
}
pub fn decrypt(
&mut self,
key_id: u64,
input_size: u64,
_nonce: Nonce,
timestamp: u64,
) -> Result<EncryptResult, &'static str> {
self.encrypt(key_id, input_size - 16, timestamp) }
pub fn stats(&self) -> EncryptStats {
self.stats.clone()
}
pub fn features(&self) -> &CpuFeatures {
&self.features
}
pub fn get_key(&self, key_id: u64) -> Option<&EncryptionKey> {
self.keys.get(&key_id)
}
}
lazy_static! {
    /// Module-wide manager shared by every `InlineEncryptEngine` call, behind a `spin::Mutex`.
    static ref INLINE_ENCRYPT_ENGINE: Mutex<InlineEncryptManager> =
        Mutex::new(InlineEncryptManager::new());
}

/// Stateless facade over the shared [`InlineEncryptManager`].
///
/// Each associated function locks the shared manager for the duration of one call
/// and forwards to the method of the same name.
pub struct InlineEncryptEngine;

impl InlineEncryptEngine {
    /// Forwards to [`InlineEncryptManager::add_key`].
    pub fn add_key(key: EncryptionKey) -> Result<(), &'static str> {
        INLINE_ENCRYPT_ENGINE.lock().add_key(key)
    }

    /// Forwards to [`InlineEncryptManager::remove_key`].
    pub fn remove_key(key_id: u64) -> bool {
        INLINE_ENCRYPT_ENGINE.lock().remove_key(key_id)
    }

    /// Forwards to [`InlineEncryptManager::encrypt`].
    pub fn encrypt(
        key_id: u64,
        input_size: u64,
        timestamp: u64,
    ) -> Result<EncryptResult, &'static str> {
        INLINE_ENCRYPT_ENGINE
            .lock()
            .encrypt(key_id, input_size, timestamp)
    }

    /// Forwards to [`InlineEncryptManager::decrypt`].
    pub fn decrypt(
        key_id: u64,
        input_size: u64,
        nonce: Nonce,
        timestamp: u64,
    ) -> Result<EncryptResult, &'static str> {
        INLINE_ENCRYPT_ENGINE
            .lock()
            .decrypt(key_id, input_size, nonce, timestamp)
    }

    /// Returns a snapshot of the shared manager's statistics.
    pub fn stats() -> EncryptStats {
        INLINE_ENCRYPT_ENGINE.lock().stats()
    }

    /// Returns a copy of the shared manager's CPU feature set.
    pub fn features() -> CpuFeatures {
        INLINE_ENCRYPT_ENGINE.lock().features().clone()
    }
}
/// AES-256 key length in bytes (256 bits).
pub const AES_KEY_SIZE: usize = 256 / 8;
/// AES-GCM nonce length in bytes (96 bits); prepended to every ciphertext.
pub const AES_NONCE_SIZE: usize = 96 / 8;
/// AES-GCM authentication tag length in bytes (128 bits).
pub const AES_TAG_SIZE: usize = 128 / 8;
#[inline]
pub fn has_aesni() -> bool {
arch::has_aesni()
}
/// Encrypts `data` with AES-256-GCM and returns `nonce || ciphertext_with_tag`.
///
/// A fresh 12-byte nonce is drawn from `fill_hardware_entropy` on every call and
/// prepended to the output. The output is therefore
/// `AES_NONCE_SIZE + data.len() + AES_TAG_SIZE` bytes long.
///
/// This function does not consult [`has_aesni`]. Any hardware-acceleration choice
/// presumably happens inside the `aes_gcm` crate.
///
/// # Errors
/// - `FsError::InvalidArgument` if `key` is not exactly `AES_KEY_SIZE` bytes.
/// - `FsError::EncryptionFailed` if nonce generation or the AEAD encryption fails.
pub fn aesni_encrypt(data: &[u8], key: &[u8]) -> Result<Vec<u8>, FsError> {
    if key.len() != AES_KEY_SIZE {
        return Err(FsError::InvalidArgument {
            reason: "AES-256 key must be 32 bytes",
        });
    }

    let mut nonce_bytes = [0u8; AES_NONCE_SIZE];
    fill_hardware_entropy(&mut nonce_bytes)?;

    let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(key));
    let sealed = cipher
        .encrypt(AesNonce::from_slice(&nonce_bytes), data)
        .map_err(|_| FsError::EncryptionFailed)?;

    let mut framed = Vec::with_capacity(nonce_bytes.len() + sealed.len());
    framed.extend_from_slice(&nonce_bytes);
    framed.extend_from_slice(&sealed);
    Ok(framed)
}
/// Decrypts a buffer produced by `aesni_encrypt` (`nonce || ciphertext_with_tag`).
///
/// # Errors
/// - `FsError::InvalidArgument` if `key` is not exactly `AES_KEY_SIZE` bytes, or if
///   `data` is shorter than `AES_NONCE_SIZE + AES_TAG_SIZE` (28) bytes.
/// - `FsError::DecryptionFailed` if authentication fails (wrong key or tampered data).
pub fn aesni_decrypt(data: &[u8], key: &[u8]) -> Result<Vec<u8>, FsError> {
    if key.len() != AES_KEY_SIZE {
        return Err(FsError::InvalidArgument {
            reason: "AES-256 key must be 32 bytes",
        });
    }
    if data.len() < AES_NONCE_SIZE + AES_TAG_SIZE {
        return Err(FsError::InvalidArgument {
            reason: "ciphertext too short (minimum 28 bytes)",
        });
    }

    // Length was checked above, so the nonce prefix is always fully present.
    let (nonce_bytes, sealed) = data.split_at(AES_NONCE_SIZE);
    let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(key));
    cipher
        .decrypt(AesNonce::from_slice(nonce_bytes), sealed)
        .map_err(|_| FsError::DecryptionFailed)
}
/// Fills `buf` with random bytes, preferring the CPU's hardware RNG.
///
/// The hardware source is used when `arch::has_rdrand()` is true and
/// `arch::fill_hardware_entropy` succeeds. Otherwise `buf` is filled by
/// `crate::crypto::random::fill_random`, and any failure there is reported as
/// `FsError::EncryptionFailed`.
fn fill_hardware_entropy(buf: &mut [u8]) -> Result<(), FsError> {
    let filled_by_hardware = arch::has_rdrand() && arch::fill_hardware_entropy(buf).is_ok();
    if filled_by_hardware {
        Ok(())
    } else {
        crate::crypto::random::fill_random(buf).map_err(|_| FsError::EncryptionFailed)
    }
}
// Unit tests for the simulated manager model and the AES-256-GCM helpers.
#[cfg(test)]
// `test_hw_accel_ratio` assigns fields one by one after `EncryptStats::default()`,
// which this clippy lint would otherwise flag.
#[allow(clippy::field_reassign_with_default)]
mod tests {
    use super::*;

    // `CpuFeatures::detect` returns fixed values, so both algorithms must be supported.
    #[test]
    fn test_cpu_features() {
        let features = CpuFeatures::detect();
        assert!(features.aes_ni);
        assert!(features.pclmulqdq);
        assert!(features.supports_algo(EncryptAlgo::Aes256Gcm));
        assert!(features.supports_algo(EncryptAlgo::ChaCha20Poly1305));
    }

    // Register, look up, and remove a key by its id.
    #[test]
    fn test_key_management() {
        let mut mgr = InlineEncryptManager::new();
        let key = EncryptionKey::generate(EncryptAlgo::Aes256Gcm, 1)
            .expect("test: operation should succeed");
        mgr.add_key(key.clone())
            .expect("test: operation should succeed");
        assert!(mgr.get_key(1).is_some());
        assert!(mgr.remove_key(1));
        assert!(mgr.get_key(1).is_none());
    }

    // Distinct counters give distinct nonces; `increment` matches the next counter value.
    #[test]
    fn test_nonce_generation() {
        let nonce1 = Nonce::from_counter(1);
        let nonce2 = Nonce::from_counter(2);
        assert_ne!(nonce1.bytes, nonce2.bytes);
        let mut nonce = Nonce::from_counter(100);
        nonce.increment();
        let expected = Nonce::from_counter(101);
        assert_eq!(nonce.bytes, expected.bytes);
    }

    // Output is input plus the 16-byte tag. AES-GCM is accelerated because `detect`
    // reports AES-NI and PCLMULQDQ.
    #[test]
    fn test_encrypt_aes_gcm() {
        let mut mgr = InlineEncryptManager::new();
        let key = EncryptionKey::generate(EncryptAlgo::Aes256Gcm, 1)
            .expect("test: operation should succeed");
        mgr.add_key(key).expect("test: operation should succeed");
        let result = mgr
            .encrypt(1, 1_000_000, 0)
            .expect("test: operation should succeed");
        assert_eq!(result.output_size, 1_000_016);
        assert!(result.hw_accelerated);
        assert!(result.exec_time_ns > 0);
    }

    // ChaCha20-Poly1305 counts as accelerated because `detect` reports AVX2.
    #[test]
    fn test_encrypt_chacha20() {
        let mut mgr = InlineEncryptManager::new();
        let key = EncryptionKey::generate(EncryptAlgo::ChaCha20Poly1305, 1)
            .expect("test: operation should succeed");
        mgr.add_key(key).expect("test: operation should succeed");
        let result = mgr
            .encrypt(1, 1_000_000, 0)
            .expect("test: operation should succeed");
        assert_eq!(result.output_size, 1_000_016);
        assert!(result.hw_accelerated);
    }

    // `decrypt` strips the 16-byte tag and re-runs `encrypt`, so it reports the same
    // output size as the ciphertext it was given.
    #[test]
    fn test_decrypt() {
        let mut mgr = InlineEncryptManager::new();
        let key = EncryptionKey::generate(EncryptAlgo::Aes256Gcm, 1)
            .expect("test: operation should succeed");
        mgr.add_key(key).expect("test: operation should succeed");
        let enc_result = mgr
            .encrypt(1, 1_000_000, 0)
            .expect("test: operation should succeed");
        let dec_result = mgr
            .decrypt(1, enc_result.output_size, enc_result.nonce, 0)
            .expect("test: operation should succeed");
        assert_eq!(dec_result.output_size, enc_result.output_size);
    }

    // 1_000_000 bytes in 210_000 ns is about 4.76 GB/s.
    #[test]
    fn test_throughput_calculation() {
        let result = EncryptResult {
            op_id: 1,
            output_size: 1_000_016,
            nonce: Nonce::from_counter(1),
            exec_time_ns: 210_000,
            hw_accelerated: true,
        };
        let throughput = result.throughput_gbps(1_000_000);
        assert!(throughput > 4.0 && throughput < 6.0);
    }

    // 210_000 ns at the assumed 3 GHz is 630_000 cycles, i.e. 0.63 cycles per byte.
    #[test]
    fn test_cycles_per_byte() {
        let result = EncryptResult {
            op_id: 1,
            output_size: 1_000_016,
            nonce: Nonce::from_counter(1),
            exec_time_ns: 210_000,
            hw_accelerated: true,
        };
        let cpb = result.cycles_per_byte(1_000_000);
        assert!(cpb > 0.5 && cpb < 1.0);
    }

    // Ten accelerated AES-GCM operations. Each is modelled at 1e6 * 0.63 / 3 ns, so
    // throughput stays above 4 GB/s.
    #[test]
    fn test_statistics() {
        let mut mgr = InlineEncryptManager::new();
        let key = EncryptionKey::generate(EncryptAlgo::Aes256Gcm, 1)
            .expect("test: operation should succeed");
        mgr.add_key(key).expect("test: operation should succeed");
        for _ in 0..10 {
            mgr.encrypt(1, 1_000_000, 0)
                .expect("test: operation should succeed");
        }
        let stats = mgr.stats();
        assert_eq!(stats.total_ops, 10);
        assert_eq!(stats.hw_accel_ops, 10);
        assert_eq!(stats.total_bytes, 10_000_000);
        assert!(stats.avg_throughput_gbps() > 4.0);
    }

    // 80 of 100 accelerated gives a ratio of exactly 0.8 in f32.
    #[test]
    fn test_hw_accel_ratio() {
        let mut stats = EncryptStats::default();
        stats.total_ops = 100;
        stats.hw_accel_ops = 80;
        stats.sw_ops = 20;
        assert_eq!(stats.hw_accel_ratio(), 0.8);
    }

    // Encrypting with an unregistered key id reports "Key not found".
    #[test]
    fn test_key_not_found() {
        let mut mgr = InlineEncryptManager::new();
        let result = mgr.encrypt(999, 1000, 0);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Key not found");
    }

    // Output differs from the plaintext and carries nonce (12) + tag (16) bytes of overhead.
    #[test]
    fn test_aesni_encrypt_actually_encrypts() {
        let plaintext = b"secret data that must be encrypted with AES-NI";
        let key = [0u8; 32];
        let ciphertext = aesni_encrypt(plaintext, &key).unwrap();
        assert_ne!(
            &ciphertext[AES_NONCE_SIZE..],
            plaintext.as_slice(),
            "AES encryption must change data!"
        );
        assert_eq!(
            ciphertext.len(),
            AES_NONCE_SIZE + plaintext.len() + AES_TAG_SIZE,
            "Ciphertext should be plaintext + 28 bytes overhead"
        );
    }

    // A fresh random nonce per call makes repeated encryptions differ.
    #[test]
    fn test_aesni_encrypt_produces_different_ciphertext_each_time() {
        let plaintext = b"same plaintext";
        let key = [0x42u8; 32];
        let ciphertext1 = aesni_encrypt(plaintext, &key).unwrap();
        let ciphertext2 = aesni_encrypt(plaintext, &key).unwrap();
        assert_ne!(
            ciphertext1, ciphertext2,
            "Same plaintext should produce different ciphertext due to random nonce"
        );
    }

    // Round trip: decrypt(encrypt(p)) == p.
    #[test]
    fn test_aesni_decrypt_reverses_encrypt() {
        let plaintext = b"secret data to round-trip through AES-GCM";
        let key = [0x42u8; 32];
        let ciphertext = aesni_encrypt(plaintext, &key).unwrap();
        let decrypted = aesni_decrypt(&ciphertext, &key).unwrap();
        assert_eq!(
            decrypted.as_slice(),
            plaintext.as_slice(),
            "AES decryption must recover original plaintext"
        );
    }

    // Authentication must reject a different key.
    #[test]
    fn test_aesni_decrypt_with_wrong_key_fails() {
        let plaintext = b"secret";
        let key1 = [0x42u8; 32];
        let key2 = [0x43u8; 32];
        let ciphertext = aesni_encrypt(plaintext, &key1).unwrap();
        assert!(
            matches!(
                aesni_decrypt(&ciphertext, &key2),
                Err(FsError::DecryptionFailed)
            ),
            "AES decryption with wrong key must fail"
        );
    }

    // Flipping one byte just after the nonce prefix must fail authentication.
    #[test]
    fn test_aesni_decrypt_tampered_ciphertext_fails() {
        let plaintext = b"secret";
        let key = [0u8; 32];
        let mut ciphertext = aesni_encrypt(plaintext, &key).unwrap();
        ciphertext[AES_NONCE_SIZE + 1] ^= 0xFF;
        assert!(
            matches!(
                aesni_decrypt(&ciphertext, &key),
                Err(FsError::DecryptionFailed)
            ),
            "Tampered AES ciphertext must fail authentication"
        );
    }

    // Keys of 16 or 64 bytes are rejected; only 32-byte keys are accepted.
    #[test]
    fn test_aesni_invalid_key_length() {
        let plaintext = b"test";
        assert!(matches!(
            aesni_encrypt(plaintext, &[0u8; 16]),
            Err(FsError::InvalidArgument { .. })
        ));
        assert!(matches!(
            aesni_encrypt(plaintext, &[0u8; 64]),
            Err(FsError::InvalidArgument { .. })
        ));
    }

    // 27 bytes is one short of the 28-byte nonce + tag minimum; empty input is rejected too.
    #[test]
    fn test_aesni_ciphertext_too_short() {
        let key = [0u8; 32];
        assert!(matches!(
            aesni_decrypt(&[0u8; 27], &key),
            Err(FsError::InvalidArgument { .. })
        ));
        assert!(matches!(
            aesni_decrypt(&[], &key),
            Err(FsError::InvalidArgument { .. })
        ));
    }

    // Empty plaintext yields exactly nonce + tag and decrypts back to empty.
    #[test]
    fn test_aesni_encrypt_decrypt_empty_plaintext() {
        let plaintext = b"";
        let key = [0x42u8; 32];
        let ciphertext = aesni_encrypt(plaintext, &key).unwrap();
        assert_eq!(ciphertext.len(), AES_NONCE_SIZE + AES_TAG_SIZE);
        let decrypted = aesni_decrypt(&ciphertext, &key).unwrap();
        assert!(decrypted.is_empty());
    }

    // Round trip of a 1 MiB buffer.
    #[test]
    fn test_aesni_encrypt_decrypt_large_data() {
        let plaintext = alloc::vec![0xAB_u8; 1024 * 1024];
        let key = [0x42u8; 32];
        let ciphertext = aesni_encrypt(&plaintext, &key).unwrap();
        let decrypted = aesni_decrypt(&ciphertext, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    // NOTE(review): asserts nothing; only checks that `has_aesni` can be called
    // without panicking.
    #[test]
    fn test_has_aesni_returns_bool() {
        let result = has_aesni();
        let _ = result;
    }
}