use aes_gcm::{
aead::{Aead, KeyInit as AeadKeyInit},
Aes256Gcm, Nonce,
};
use hkdf::Hkdf;
use sha3::Sha3_256;
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use ml_kem::{MlKem1024, Seed, SharedKey, Kem, kem::{Decapsulate, Encapsulate, KeyExport}};
use ml_kem::{DecapsulationKey1024, EncapsulationKey1024};
use serde::{Deserialize, Serialize};
use std::sync::Mutex;
use zeroize::Zeroize;
const DOMAIN: &[u8] = b"Abir-Guard-Hybrid-2026";
const HANDSHAKE_TIMEOUT_MS: u128 = 200;
#[derive(Debug)]
pub struct SecurityWatchdogError {
pub elapsed_ms: u128,
pub threshold_ms: u128,
}
impl std::fmt::Display for SecurityWatchdogError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Latency anomaly: {}ms (threshold: {}ms). Potential side-channel attack.",
self.elapsed_ms, self.threshold_ms)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyPair {
pub public_key: String,
pub secret_key: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ciphertext {
pub nonce: String,
pub ciphertext: String,
pub key_id: String,
}
pub struct MlKem1024Kem {
_private: (),
}
impl MlKem1024Kem {
pub fn new() -> Self {
Self { _private: () }
}
pub fn is_available() -> bool {
true
}
pub fn keygen(&self) -> (KeyPair, Vec<u8>) {
let (dk, ek) = MlKem1024::generate_keypair();
let ek_bytes = ek.to_bytes();
let dk_seed: Seed = dk.to_seed().expect("key from seed");
let kp = KeyPair {
public_key: BASE64.encode(ek_bytes.as_slice()),
secret_key: BASE64.encode(dk_seed.as_slice()),
};
(kp, dk_seed.to_vec())
}
pub fn encapsulate(&self, public_key_b64: &str) -> Result<(Vec<u8>, Vec<u8>), String> {
let start = std::time::Instant::now();
let pk_bytes = BASE64.decode(public_key_b64).map_err(|e| e.to_string())?;
let ek = EncapsulationKey1024::new(pk_bytes.as_slice().try_into().map_err(|_| "Invalid public key length")?).map_err(|_| "Invalid public key")?;
let (ct, ss) = ek.encapsulate();
let ct_bytes: & [u8] = ct.as_ref();
let ss_bytes: &[u8] = ss.as_ref();
let elapsed = start.elapsed().as_millis();
if elapsed > HANDSHAKE_TIMEOUT_MS {
return Err(format!(
"Security watchdog: encapsulation took {}ms (threshold: {}ms)",
elapsed, HANDSHAKE_TIMEOUT_MS
));
}
Ok((ct_bytes.to_vec(), ss_bytes.to_vec()))
}
pub fn decapsulate(&self, ciphertext: &[u8], secret_key_b64: &str) -> Result<Vec<u8>, String> {
let start = std::time::Instant::now();
let sk_bytes = BASE64.decode(secret_key_b64).map_err(|e| e.to_string())?;
let seed: Seed = sk_bytes.as_slice().try_into().map_err(|_| "Invalid secret key length")?;
let dk = DecapsulationKey1024::from_seed(seed);
let ct: ml_kem::ml_kem_1024::Ciphertext = ciphertext.try_into().map_err(|_| "Invalid ciphertext length")?;
let ss: SharedKey = dk.decapsulate(&ct);
let ss_bytes: &[u8] = ss.as_ref();
let elapsed = start.elapsed().as_millis();
if elapsed > HANDSHAKE_TIMEOUT_MS {
return Err(format!(
"Security watchdog: decapsulation took {}ms (threshold: {}ms)",
elapsed, HANDSHAKE_TIMEOUT_MS
));
}
Ok(ss_bytes.to_vec())
}
}
impl Default for MlKem1024Kem {
fn default() -> Self {
Self::new()
}
}
pub struct HybridEncryptor {
#[allow(dead_code)]
key_size: usize,
ml_kem: MlKem1024Kem,
}
impl HybridEncryptor {
pub fn new() -> Self {
Self {
key_size: 32,
ml_kem: MlKem1024Kem::new(),
}
}
pub fn generate_keypair(&self) -> (KeyPair, Vec<u8>) {
self.ml_kem.keygen()
}
pub fn encrypt(&self, plaintext: &[u8], secret: &[u8]) -> Ciphertext {
let start = std::time::Instant::now();
let hkdf = Hkdf::<Sha3_256>::new(Some(DOMAIN), secret);
let mut key = [0u8; 32];
hkdf.expand(b"aes-key", &mut key).expect("HKDF expand");
let cipher = Aes256Gcm::new_from_slice(&key).expect("Valid key");
let mut nonce_bytes = [0u8; 12];
getrandom::fill(&mut nonce_bytes).expect("Failed to get random nonce");
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher.encrypt(nonce, plaintext).expect("Encrypt");
key.zeroize();
let elapsed = start.elapsed().as_millis();
if elapsed > HANDSHAKE_TIMEOUT_MS {
panic!(
"Security watchdog: encryption took {}ms (threshold: {}ms). Potential side-channel attack.",
elapsed, HANDSHAKE_TIMEOUT_MS
);
}
Ciphertext {
nonce: BASE64.encode(&nonce_bytes[..]),
ciphertext: BASE64.encode(ciphertext.as_slice()),
key_id: String::new(),
}
}
pub fn decrypt(&self, ct: &Ciphertext, secret: &[u8]) -> Result<Vec<u8>, String> {
let start = std::time::Instant::now();
let hkdf = Hkdf::<Sha3_256>::new(Some(DOMAIN), secret);
let mut key = [0u8; 32];
hkdf.expand(b"aes-key", &mut key).expect("HKDF expand");
let cipher = Aes256Gcm::new_from_slice(&key).map_err(|e| e.to_string())?;
let nonce = BASE64.decode(&ct.nonce).map_err(|e| e.to_string())?;
let nonce = Nonce::from_slice(&nonce);
let ciphertext = BASE64.decode(&ct.ciphertext).map_err(|e| e.to_string())?;
let plaintext = cipher.decrypt(nonce, ciphertext.as_ref()).map_err(|e| e.to_string())?;
key.zeroize();
let elapsed = start.elapsed().as_millis();
if elapsed > HANDSHAKE_TIMEOUT_MS {
return Err(format!(
"Security watchdog: decryption took {}ms (threshold: {}ms)",
elapsed, HANDSHAKE_TIMEOUT_MS
));
}
Ok(plaintext)
}
pub fn ml_kem(&self) -> &MlKem1024Kem {
&self.ml_kem
}
}
impl Default for HybridEncryptor {
fn default() -> Self {
Self::new()
}
}
pub struct Vault {
encryptor: HybridEncryptor,
keys: Mutex<Vec<(String, KeyPair, Vec<u8>)>>,
}
impl Vault {
pub fn new() -> Self {
Self {
encryptor: HybridEncryptor::new(),
keys: Mutex::new(Vec::new()),
}
}
pub fn generate_keypair(&self, key_id: &str) -> (String, String) {
let (kp, sk) = self.encryptor.generate_keypair();
let mut keys = self.keys.lock().unwrap();
keys.push((key_id.to_string(), kp.clone(), sk));
(kp.public_key, kp.secret_key)
}
pub fn store(&self, key_id: &[u8], plaintext: &[u8]) -> Result<Ciphertext, String> {
let key_id_str = String::from_utf8(key_id.to_vec()).map_err(|e| e.to_string())?;
let sk;
{
let mut keys = self.keys.lock().unwrap();
match keys.iter()
.find(|(id, _, _)| id.as_bytes() == key_id)
{
Some((_, _, stored_sk)) => sk = stored_sk.clone(),
None => {
let (kp, generated_sk) = self.encryptor.generate_keypair();
keys.push((key_id_str.clone(), kp, generated_sk.clone()));
sk = generated_sk;
}
}
}
Ok(self.encryptor.encrypt(plaintext, &sk))
}
pub fn retrieve(&self, key_id: &[u8], ct: &Ciphertext) -> Result<Vec<u8>, String> {
let keys = self.keys.lock().unwrap();
let (_, _, sk) = keys.iter()
.find(|(id, _, _)| id.as_bytes() == key_id)
.ok_or("Key not found")?;
self.encryptor.decrypt(ct, sk)
}
pub fn list_keypairs(&self) -> Vec<String> {
let keys = self.keys.lock().unwrap();
keys.iter().map(|(id, _, _)| id.clone()).collect()
}
pub fn remove_keypair(&self, key_id: &str) {
let mut keys = self.keys.lock().unwrap();
keys.retain(|(id, _, _)| id != key_id);
}
pub fn export_keys(&self) -> Vec<(String, String, String)> {
let keys = self.keys.lock().unwrap();
keys.iter()
.map(|(id, kp, _sk)| (id.clone(), kp.public_key.clone(), kp.secret_key.clone()))
.collect()
}
pub fn import_key(&self, key_id: &str, public_key: &str, secret_key_b64: &str) -> Result<(), String> {
let sk = base64_decode(secret_key_b64)?;
let kp = KeyPair {
public_key: public_key.to_string(),
secret_key: secret_key_b64.to_string(),
};
let mut keys = self.keys.lock().unwrap();
keys.retain(|(id, _, _)| id != key_id);
keys.push((key_id.to_string(), kp, sk));
Ok(())
}
pub fn ml_kem_roundtrip(&self) -> Result<bool, String> {
let kem = self.encryptor.ml_kem();
let (kp, _sk) = kem.keygen();
let (ct, ss1) = kem.encapsulate(&kp.public_key)?;
let ss2 = kem.decapsulate(&ct, &kp.secret_key)?;
Ok(ss1 == ss2)
}
}
fn base64_decode(s: &str) -> Result<Vec<u8>, String> {
use base64::Engine;
BASE64.decode(s).map_err(|e| e.to_string())
}
impl Default for Vault {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_vault() {
let vault = Vault::new();
let (_, _) = vault.generate_keypair("test");
let ct = vault.store(b"test", b"secret").unwrap();
let plain = vault.retrieve(b"test", &ct).unwrap();
assert_eq!(plain, b"secret");
}
#[test]
fn test_auto_generate_key() {
let vault = Vault::new();
let ct = vault.store(b"auto-key", b"auto-secret").unwrap();
let plain = vault.retrieve(b"auto-key", &ct).unwrap();
assert_eq!(plain, b"auto-secret");
}
#[test]
fn test_key_zeroization() {
let encryptor = HybridEncryptor::new();
let (_, secret) = encryptor.generate_keypair();
let ct = encryptor.encrypt(b"test data", &secret);
let plain = encryptor.decrypt(&ct, &secret).unwrap();
assert_eq!(plain, b"test data");
}
#[test]
fn test_ml_kem_1024_roundtrip() {
let kem = MlKem1024Kem::new();
assert!(MlKem1024Kem::is_available());
let (kp, _sk) = kem.keygen();
let (ct, ss1) = kem.encapsulate(&kp.public_key).unwrap();
let ss2 = kem.decapsulate(&ct, &kp.secret_key).unwrap();
assert_eq!(ss1, ss2, "ML-KEM-1024 shared secrets must match");
}
#[test]
fn test_ml_kem_key_sizes() {
let kem = MlKem1024Kem::new();
let (kp, _sk) = kem.keygen();
assert_eq!(BASE64.decode(&kp.public_key).unwrap().len(), 1568, "PK size mismatch");
assert_eq!(BASE64.decode(&kp.secret_key).unwrap().len(), 64, "SK seed size mismatch");
}
#[test]
fn test_vault_ml_kem_roundtrip() {
let vault = Vault::new();
assert!(vault.ml_kem_roundtrip().unwrap());
}
}