use crate::errors::{BottleError, Result};
use crate::tpm::ECDHHandler;
use p256::ecdh::EphemeralSecret;
use p256::{PublicKey, SecretKey};
use rand::{CryptoRng, RngCore};
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
#[cfg(feature = "ml-kem")]
use hybrid_array::{
sizes::{U1088, U1568},
Array,
};
#[cfg(feature = "ml-kem")]
use ml_kem::{kem::Kem, EncodedSizeUser, KemCore, MlKem1024Params, MlKem768Params};
#[cfg(feature = "ml-kem")]
use zerocopy::AsBytes;
/// Encrypts `plaintext` to a P-256 `public_key` using a fresh ephemeral key pair.
///
/// Output layout: the ephemeral public key as returned by `PublicKey::to_sec1_bytes()`,
/// followed by the AES-256-GCM output of [`encrypt_aes_gcm`]. The symmetric key is the
/// SHA-256 of the raw ECDH shared secret (see [`derive_key`]).
///
/// # Errors
/// Propagates any error from [`encrypt_aes_gcm`].
pub fn ecdh_encrypt_p256<R: RngCore + CryptoRng>(
    rng: &mut R,
    plaintext: &[u8],
    public_key: &PublicKey,
) -> Result<Vec<u8>> {
    let ephemeral = EphemeralSecret::random(rng);
    // The recipient needs the sender's public half to recompute the shared secret.
    let header = ephemeral.public_key().to_sec1_bytes();
    let key = derive_key(ephemeral.diffie_hellman(public_key).raw_secret_bytes().as_ref());
    let body = encrypt_aes_gcm(&key, plaintext)?;

    let mut out = Vec::with_capacity(header.len() + body.len());
    out.extend_from_slice(&header);
    out.extend_from_slice(&body);
    Ok(out)
}
/// Decrypts a ciphertext produced by [`ecdh_encrypt_p256`] with the recipient's P-256 key.
///
/// Expects the first 65 bytes to be the sender's uncompressed SEC1 ephemeral public key,
/// followed by the AES-256-GCM payload. The symmetric key is SHA-256 of the x-coordinate
/// of `private_key * ephemeral_pub`.
///
/// # Errors
/// - `BottleError::InvalidFormat` if `ciphertext` is shorter than 65 bytes.
/// - `BottleError::Decryption` if the ephemeral key does not decode, if the shared point
///   has no x-coordinate, or if AES-GCM authentication fails.
pub fn ecdh_decrypt_p256(ciphertext: &[u8], private_key: &SecretKey) -> Result<Vec<u8>> {
    use p256::elliptic_curve::sec1::ToEncodedPoint;

    // Uncompressed SEC1 point: 0x04 || x (32) || y (32).
    const EPHEMERAL_LEN: usize = 65;

    if ciphertext.len() < EPHEMERAL_LEN {
        return Err(BottleError::InvalidFormat);
    }
    let (ephemeral_bytes, sealed) = ciphertext.split_at(EPHEMERAL_LEN);
    let ephemeral_pub = PublicKey::from_sec1_bytes(ephemeral_bytes)
        .map_err(|_| BottleError::Decryption("Invalid ephemeral public key".to_string()))?;

    let scalar = private_key.to_nonzero_scalar();
    let shared_point = (*ephemeral_pub.as_affine() * scalar.as_ref()).to_encoded_point(false);
    // The ephemeral key is attacker-supplied, so report a missing x-coordinate (identity
    // point) as a decryption error instead of panicking.
    let shared_x = shared_point
        .x()
        .ok_or_else(|| BottleError::Decryption("Invalid ECDH shared point".to_string()))?;

    let key = derive_key(shared_x.as_ref());
    decrypt_aes_gcm(&key, sealed)
}
/// Encrypts `plaintext` to an X25519 `public_key` with a one-off sender key.
///
/// The sender secret is built from 32 bytes drawn from `rng`. Output layout: the 32-byte
/// sender public key, followed by the AES-256-GCM output of [`encrypt_aes_gcm`] keyed with
/// SHA-256 of the X25519 shared secret.
///
/// # Errors
/// Propagates any error from [`encrypt_aes_gcm`].
pub fn ecdh_encrypt_x25519<R: RngCore>(
    rng: &mut R,
    plaintext: &[u8],
    public_key: &X25519PublicKey,
) -> Result<Vec<u8>> {
    let ephemeral = {
        let mut seed = [0u8; 32];
        rng.fill_bytes(&mut seed);
        StaticSecret::from(seed)
    };
    let ephemeral_pub = X25519PublicKey::from(&ephemeral);

    let key = derive_key(ephemeral.diffie_hellman(public_key).as_bytes());
    let sealed = encrypt_aes_gcm(&key, plaintext)?;

    Ok(ephemeral_pub
        .as_bytes()
        .iter()
        .chain(sealed.iter())
        .copied()
        .collect())
}
/// Decrypts a ciphertext produced by [`ecdh_encrypt_x25519`].
///
/// The first 32 bytes are the sender's X25519 public key. The remainder is the AES-256-GCM
/// payload keyed with SHA-256 of the shared secret.
///
/// # Errors
/// - `BottleError::InvalidFormat` if `ciphertext` is shorter than 32 bytes.
/// - Any error from [`decrypt_aes_gcm`] (e.g. authentication failure).
pub fn ecdh_decrypt_x25519(ciphertext: &[u8], private_key: &[u8; 32]) -> Result<Vec<u8>> {
    if ciphertext.len() < 32 {
        return Err(BottleError::InvalidFormat);
    }
    let (header, sealed) = ciphertext.split_at(32);
    let peer_bytes = <[u8; 32]>::try_from(header).map_err(|_| BottleError::InvalidFormat)?;

    let shared = StaticSecret::from(*private_key)
        .diffie_hellman(&X25519PublicKey::from(peer_bytes));
    decrypt_aes_gcm(&derive_key(shared.as_bytes()), sealed)
}
/// Pluggable public-key encryption over raw encoded key bytes.
///
/// Implementors receive the caller's RNG so they can generate any per-message randomness
/// (the free functions in this module use it for ephemeral keys).
pub trait ECDHEncrypt {
    /// Encrypts `plaintext` to the encoded `public_key`.
    fn encrypt<R>(&self, rng: &mut R, plaintext: &[u8], public_key: &[u8]) -> Result<Vec<u8>>
    where
        R: RngCore;
}
/// Decryption counterpart of [`ECDHEncrypt`], operating on raw encoded private-key bytes.
pub trait ECDHDecrypt {
    /// Recovers the plaintext of `ciphertext` using the encoded `private_key`.
    fn decrypt(&self, ciphertext: &[u8], private_key: &[u8]) -> Result<Vec<u8>>;
}
pub fn ecdh_encrypt<R: RngCore + CryptoRng>(
rng: &mut R,
plaintext: &[u8],
public_key: &[u8],
) -> Result<Vec<u8>> {
ecdh_encrypt_with_handler(rng, plaintext, public_key, None)
}
/// Encrypts `plaintext` to `public_key`, choosing the scheme from the key's encoded length.
///
/// Accepted encodings:
/// - 32 bytes: X25519 public key.
/// - 33 bytes: SEC1 compressed P-256 point.
/// - 65 bytes: SEC1 uncompressed P-256 point (`0x04 || x || y`).
/// - 64 bytes: raw P-256 coordinates `x || y` without the SEC1 tag byte.
/// - 1184 / 1568 bytes (only with the `ml-kem` feature): ML-KEM-768 / ML-KEM-1024
///   encapsulation key.
///
/// `_handler` is not used by this function.
///
/// # Errors
/// Returns `BottleError::InvalidKeyType` in two cases: the length is not one of the above,
/// or the bytes do not decode to a valid key. Otherwise returns whatever the selected
/// scheme returns.
pub fn ecdh_encrypt_with_handler<R: RngCore + CryptoRng>(
    rng: &mut R,
    plaintext: &[u8],
    public_key: &[u8],
    _handler: Option<&dyn ECDHHandler>,
) -> Result<Vec<u8>> {
    match public_key.len() {
        32 => {
            let pub_key_bytes: [u8; 32] = public_key
                .try_into()
                .map_err(|_| BottleError::InvalidKeyType)?;
            ecdh_encrypt_x25519(rng, plaintext, &X25519PublicKey::from(pub_key_bytes))
        }
        33 | 65 => {
            let pub_key =
                PublicKey::from_sec1_bytes(public_key).map_err(|_| BottleError::InvalidKeyType)?;
            ecdh_encrypt_p256(rng, plaintext, &pub_key)
        }
        64 => {
            // 64 bytes is not a valid SEC1 encoding by itself. Interpret it as raw `x || y`
            // and prepend the "uncompressed" tag so the SEC1 decoder still validates that
            // the point is on the curve.
            let mut sec1 = [0u8; 65];
            sec1[0] = 0x04;
            sec1[1..].copy_from_slice(public_key);
            let pub_key =
                PublicKey::from_sec1_bytes(&sec1).map_err(|_| BottleError::InvalidKeyType)?;
            ecdh_encrypt_p256(rng, plaintext, &pub_key)
        }
        #[cfg(feature = "ml-kem")]
        1184 => mlkem768_encrypt(rng, plaintext, public_key),
        #[cfg(feature = "ml-kem")]
        1568 => mlkem1024_encrypt(rng, plaintext, public_key),
        _ => Err(BottleError::InvalidKeyType),
    }
}
pub fn ecdh_decrypt(ciphertext: &[u8], private_key: &[u8]) -> Result<Vec<u8>> {
ecdh_decrypt_with_handler(ciphertext, private_key, None)
}
/// Decrypts `ciphertext` using either an [`ECDHHandler`] or a raw private key.
///
/// Resolution order:
/// 1. If `handler` is `Some`:
///    - The leading `h.public_key()?.len()` bytes are taken as the sender's ephemeral
///      public key.
///    - The handler performs ECDH on them.
///    - The remainder is decrypted with AES-256-GCM.
///    - `private_key` is not consulted.
/// 2. With the `ml-kem` feature:
///    - A 2400/3584-byte key is tried as ML-KEM-768.
///    - A 3168/4736-byte key is tried as ML-KEM-1024.
/// 3. A 32-byte key is tried first as X25519 (only when the ciphertext has at least 32
///    bytes), then as a P-256 scalar.
///
/// # Errors
/// - On the handler path, handler errors are returned directly.
/// - On the handler path, `BottleError::InvalidFormat` is returned if the ciphertext is
///   shorter than the handler's public key.
/// - On the key path, individual failures are discarded and `BottleError::InvalidKeyType`
///   is returned when no scheme succeeds. An unsupported key and a failed authentication
///   are therefore reported the same way.
pub fn ecdh_decrypt_with_handler(
    ciphertext: &[u8],
    private_key: &[u8],
    handler: Option<&dyn ECDHHandler>,
) -> Result<Vec<u8>> {
    if let Some(h) = handler {
        let ephemeral_len = h.public_key()?.len();
        if ciphertext.len() < ephemeral_len {
            return Err(BottleError::InvalidFormat);
        }
        let (ephemeral_pub_key, encrypted_data) = ciphertext.split_at(ephemeral_len);
        let shared_secret = h.ecdh(ephemeral_pub_key)?;
        return decrypt_aes_gcm(&derive_key(&shared_secret), encrypted_data);
    }

    #[cfg(feature = "ml-kem")]
    {
        // The two parameter sets use disjoint key lengths, so at most one arm applies.
        let mlkem_plaintext = match private_key.len() {
            2400 | 3584 => mlkem768_decrypt(ciphertext, private_key).ok(),
            3168 | 4736 => mlkem1024_decrypt(ciphertext, private_key).ok(),
            _ => None,
        };
        if let Some(plaintext) = mlkem_plaintext {
            return Ok(plaintext);
        }
    }

    if let Ok(priv_key_bytes) = <[u8; 32]>::try_from(private_key) {
        // X25519 and P-256 secrets are both 32 bytes, so the key alone cannot tell them
        // apart; AES-GCM authentication decides which interpretation was intended.
        let x25519_attempt = if ciphertext.len() >= 32 {
            ecdh_decrypt_x25519(ciphertext, &priv_key_bytes).ok()
        } else {
            None
        };
        if let Some(plaintext) = x25519_attempt {
            return Ok(plaintext);
        }

        let p256_attempt = SecretKey::from_bytes(private_key.into())
            .ok()
            .and_then(|priv_key| ecdh_decrypt_p256(ciphertext, &priv_key).ok());
        if let Some(plaintext) = p256_attempt {
            return Ok(plaintext);
        }
    }

    Err(BottleError::InvalidKeyType)
}
#[cfg(feature = "ml-kem")]
/// Encrypts `plaintext` to an ML-KEM-768 encapsulation key.
///
/// Output layout: the ML-KEM ciphertext bytes, followed by the AES-256-GCM output of
/// [`encrypt_aes_gcm`]. The symmetric key is SHA-256 of the encapsulated shared secret.
///
/// # Errors
/// - `BottleError::InvalidKeyType` if `public_key` is not exactly 1184 bytes.
/// - `BottleError::Encryption` if encapsulation or AES-GCM sealing fails.
pub fn mlkem768_encrypt<R: RngCore + CryptoRng>(
    rng: &mut R,
    plaintext: &[u8],
    public_key: &[u8],
) -> Result<Vec<u8>> {
    use ml_kem::kem::Encapsulate;
    use rand_core_09::{CryptoRng as CryptoRng09, RngCore as RngCore09};

    // Bridges the caller's RNG to the rand_core 0.9 traits that `encapsulate` requires.
    struct Compat09<'r, G>(&'r mut G);

    impl<G: RngCore + CryptoRng> RngCore09 for Compat09<'_, G> {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }
        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }
        fn fill_bytes(&mut self, dst: &mut [u8]) {
            self.0.fill_bytes(dst)
        }
    }
    impl<G: RngCore + CryptoRng> CryptoRng09 for Compat09<'_, G> {}

    let encoded_ek: [u8; 1184] = match public_key.try_into() {
        Ok(bytes) => bytes,
        Err(_) => return Err(BottleError::InvalidKeyType),
    };
    let ek =
        <Kem<MlKem768Params> as KemCore>::EncapsulationKey::from_bytes((&encoded_ek).into());

    let mut compat_rng = Compat09(rng);
    let (ct, shared_secret) = ek
        .encapsulate(&mut compat_rng)
        .map_err(|_| BottleError::Encryption("ML-KEM encapsulation failed".to_string()))?;

    let sealed = encrypt_aes_gcm(&derive_key(&shared_secret), plaintext)?;
    Ok([ct.as_bytes(), sealed.as_slice()].concat())
}
#[cfg(feature = "ml-kem")]
/// Decrypts a ciphertext produced by [`mlkem768_encrypt`].
///
/// `secret_key` may be either form below; in both cases only the first 2400 bytes are
/// used:
/// - the 2400-byte decapsulation key alone;
/// - a 3584-byte buffer.
///
/// `ciphertext` must hold the 1088-byte ML-KEM ciphertext followed by at least 28 bytes
/// (12-byte nonce + 16-byte GCM tag).
///
/// # Errors
/// - `BottleError::InvalidKeyType` for any other key length.
/// - `BottleError::InvalidFormat` if the ciphertext is too short.
/// - `BottleError::Decryption` if decapsulation or AES-GCM authentication fails.
pub fn mlkem768_decrypt(ciphertext: &[u8], secret_key: &[u8]) -> Result<Vec<u8>> {
    use ml_kem::kem::Decapsulate;

    const DK_SIZE: usize = 2400;
    const CT_SIZE: usize = 1088;

    let dk_bytes = match secret_key.len() {
        DK_SIZE | 3584 => &secret_key[..DK_SIZE],
        _ => return Err(BottleError::InvalidKeyType),
    };
    let sec_key_array: [u8; DK_SIZE] = dk_bytes
        .try_into()
        .map_err(|_| BottleError::InvalidKeyType)?;
    let dk =
        <Kem<MlKem768Params> as KemCore>::DecapsulationKey::from_bytes((&sec_key_array).into());

    if ciphertext.len() < CT_SIZE + 28 {
        return Err(BottleError::InvalidFormat);
    }
    let (kem_part, aes_ct) = ciphertext.split_at(CT_SIZE);
    let ct_array: [u8; CT_SIZE] = kem_part
        .try_into()
        .map_err(|_| BottleError::InvalidFormat)?;
    let mlkem_ct: Array<u8, U1088> = ct_array.into();

    let shared_secret = dk
        .decapsulate(&mlkem_ct)
        .map_err(|_| BottleError::Decryption("ML-KEM decapsulation failed".to_string()))?;
    decrypt_aes_gcm(&derive_key(&shared_secret), aes_ct)
}
#[cfg(feature = "ml-kem")]
/// Encrypts `plaintext` to an ML-KEM-1024 encapsulation key.
///
/// Output layout: the ML-KEM ciphertext bytes, followed by the AES-256-GCM output of
/// [`encrypt_aes_gcm`]. The symmetric key is SHA-256 of the encapsulated shared secret.
///
/// # Errors
/// - `BottleError::InvalidKeyType` if `public_key` is not exactly 1568 bytes.
/// - `BottleError::Encryption` if encapsulation or AES-GCM sealing fails.
pub fn mlkem1024_encrypt<R: RngCore + CryptoRng>(
    rng: &mut R,
    plaintext: &[u8],
    public_key: &[u8],
) -> Result<Vec<u8>> {
    use ml_kem::kem::Encapsulate;
    use rand_core_09::{CryptoRng as CryptoRng09, RngCore as RngCore09};

    // Bridges the caller's RNG to the rand_core 0.9 traits that `encapsulate` requires.
    struct Compat09<'r, G>(&'r mut G);

    impl<G: RngCore + CryptoRng> RngCore09 for Compat09<'_, G> {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }
        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }
        fn fill_bytes(&mut self, dst: &mut [u8]) {
            self.0.fill_bytes(dst)
        }
    }
    impl<G: RngCore + CryptoRng> CryptoRng09 for Compat09<'_, G> {}

    let encoded_ek: [u8; 1568] = match public_key.try_into() {
        Ok(bytes) => bytes,
        Err(_) => return Err(BottleError::InvalidKeyType),
    };
    let ek =
        <Kem<MlKem1024Params> as KemCore>::EncapsulationKey::from_bytes((&encoded_ek).into());

    let mut compat_rng = Compat09(rng);
    let (ct, shared_secret) = ek
        .encapsulate(&mut compat_rng)
        .map_err(|_| BottleError::Encryption("ML-KEM encapsulation failed".to_string()))?;

    let sealed = encrypt_aes_gcm(&derive_key(&shared_secret), plaintext)?;
    Ok([ct.as_bytes(), sealed.as_slice()].concat())
}
#[cfg(feature = "ml-kem")]
/// Decrypts a ciphertext produced by [`mlkem1024_encrypt`].
///
/// `secret_key` may be either form below; in both cases only the first 3168 bytes are
/// used:
/// - the 3168-byte decapsulation key alone;
/// - a 4736-byte buffer.
///
/// `ciphertext` must hold the 1568-byte ML-KEM ciphertext followed by at least 28 bytes
/// (12-byte nonce + 16-byte GCM tag).
///
/// # Errors
/// - `BottleError::InvalidKeyType` for any other key length.
/// - `BottleError::InvalidFormat` if the ciphertext is too short.
/// - `BottleError::Decryption` if decapsulation or AES-GCM authentication fails.
pub fn mlkem1024_decrypt(ciphertext: &[u8], secret_key: &[u8]) -> Result<Vec<u8>> {
    use ml_kem::kem::Decapsulate;

    const DK_SIZE: usize = 3168;
    const CT_SIZE: usize = 1568;

    let dk_bytes = match secret_key.len() {
        DK_SIZE | 4736 => &secret_key[..DK_SIZE],
        _ => return Err(BottleError::InvalidKeyType),
    };
    let sec_key_array: [u8; DK_SIZE] = dk_bytes
        .try_into()
        .map_err(|_| BottleError::InvalidKeyType)?;
    let dk =
        <Kem<MlKem1024Params> as KemCore>::DecapsulationKey::from_bytes((&sec_key_array).into());

    if ciphertext.len() < CT_SIZE + 28 {
        return Err(BottleError::InvalidFormat);
    }
    let (kem_part, aes_ct) = ciphertext.split_at(CT_SIZE);
    let ct_array: [u8; CT_SIZE] = kem_part
        .try_into()
        .map_err(|_| BottleError::InvalidFormat)?;
    let mlkem_ct: Array<u8, U1568> = ct_array.into();

    let shared_secret = dk
        .decapsulate(&mlkem_ct)
        .map_err(|_| BottleError::Decryption("ML-KEM decapsulation failed".to_string()))?;
    decrypt_aes_gcm(&derive_key(&shared_secret), aes_ct)
}
#[cfg(feature = "ml-kem")]
/// Encrypts `plaintext` under both ML-KEM-768 and X25519.
///
/// Output layout:
/// `len (u32 LE) || mlkem768_encrypt(..) (len bytes) || ecdh_encrypt_x25519(..)`.
///
/// NOTE(review): the full plaintext is sealed independently under each component, and
/// nothing combines the two shared secrets. Confidentiality is therefore only as strong as
/// the weaker scheme: anyone who breaks X25519 can read the second segment directly. A
/// combined-KEM construction would fix this but changes the wire format.
///
/// # Errors
/// - `BottleError::InvalidKeyType` if either public key is malformed.
/// - `BottleError::Encryption` if the ML-KEM segment does not fit the 32-bit length prefix.
/// - Otherwise, any error from the component encryptors.
pub fn hybrid_encrypt_mlkem768_x25519<R: RngCore + CryptoRng>(
    rng: &mut R,
    plaintext: &[u8],
    mlkem_pub: &[u8],
    x25519_pub: &[u8],
) -> Result<Vec<u8>> {
    let mlkem_ct = mlkem768_encrypt(rng, plaintext, mlkem_pub)?;
    // A plain `as u32` would silently truncate the prefix for >4 GiB segments and produce
    // output that can never be parsed back.
    let mlkem_len = u32::try_from(mlkem_ct.len()).map_err(|_| {
        BottleError::Encryption("ML-KEM ciphertext too large for length prefix".to_string())
    })?;

    let x25519_pub_bytes: [u8; 32] = x25519_pub
        .try_into()
        .map_err(|_| BottleError::InvalidKeyType)?;
    let x25519_pub_key = X25519PublicKey::from(x25519_pub_bytes);
    let x25519_ct = ecdh_encrypt_x25519(rng, plaintext, &x25519_pub_key)?;

    let mut result = Vec::with_capacity(4 + mlkem_ct.len() + x25519_ct.len());
    result.extend_from_slice(&mlkem_len.to_le_bytes());
    result.extend_from_slice(&mlkem_ct);
    result.extend_from_slice(&x25519_ct);
    Ok(result)
}
#[cfg(feature = "ml-kem")]
/// Decrypts the output of [`hybrid_encrypt_mlkem768_x25519`].
///
/// The ciphertext is parsed as `len (u32 LE) || ML-KEM segment (len bytes) || X25519 segment`.
/// The ML-KEM segment is tried first. If it fails for any reason, the X25519 segment is
/// decrypted instead.
///
/// NOTE(review): the fallback means a corrupted or stripped ML-KEM segment silently
/// downgrades to classical X25519 security.
///
/// # Errors
/// - `BottleError::InvalidFormat` if the length prefix is missing or exceeds the remaining
///   input.
/// - Otherwise, the X25519 path's error when both components fail.
pub fn hybrid_decrypt_mlkem768_x25519(
    ciphertext: &[u8],
    mlkem_sec: &[u8],
    x25519_sec: &[u8; 32],
) -> Result<Vec<u8>> {
    if ciphertext.len() < 4 {
        return Err(BottleError::InvalidFormat);
    }
    let (prefix, rest) = ciphertext.split_at(4);
    let prefix: [u8; 4] = prefix.try_into().map_err(|_| BottleError::InvalidFormat)?;
    let mlkem_len =
        usize::try_from(u32::from_le_bytes(prefix)).map_err(|_| BottleError::InvalidFormat)?;

    // Compare against the remainder instead of computing `4 + mlkem_len`: the length is
    // attacker-controlled and that sum can overflow `usize` on 32-bit targets.
    if rest.len() < mlkem_len {
        return Err(BottleError::InvalidFormat);
    }
    let (mlkem_ct, x25519_ct) = rest.split_at(mlkem_len);

    mlkem768_decrypt(mlkem_ct, mlkem_sec).or_else(|_| ecdh_decrypt_x25519(x25519_ct, x25519_sec))
}
fn derive_key(shared_secret: &[u8]) -> [u8; 32] {
use sha2::Digest;
use sha2::Sha256;
let mut hasher = Sha256::new();
hasher.update(shared_secret);
let hash = hasher.finalize();
let mut key = [0u8; 32];
key.copy_from_slice(&hash);
key
}
/// Seals `plaintext` with AES-256-GCM under `key`, using a random 96-bit nonce and an
/// empty AAD.
///
/// Output layout: `nonce (12 bytes) || ciphertext || tag (16 bytes)`.
///
/// Every caller in this module derives `key` from a fresh ephemeral exchange. A random
/// nonce is used instead of a counter, and `ring`'s `LessSafeKey` is the API intended for
/// caller-managed nonces.
///
/// # Errors
/// `BottleError::Encryption` if the system RNG, key setup, or sealing fails.
fn encrypt_aes_gcm(key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>> {
    use ring::aead::{self, LessSafeKey, UnboundKey, NONCE_LEN};
    use ring::rand::{SecureRandom, SystemRandom};

    let mut nonce_bytes = [0u8; NONCE_LEN];
    SystemRandom::new()
        .fill(&mut nonce_bytes)
        .map_err(|_| BottleError::Encryption("RNG failure".to_string()))?;

    let unbound_key = UnboundKey::new(&aead::AES_256_GCM, key)
        .map_err(|_| BottleError::Encryption("Key creation failed".to_string()))?;
    let sealing_key = LessSafeKey::new(unbound_key);

    // Build `nonce || plaintext` once, encrypt the plaintext region in place, then append
    // the tag. This avoids a second buffer.
    let tag_len = aead::AES_256_GCM.tag_len();
    let mut result = Vec::with_capacity(NONCE_LEN + plaintext.len() + tag_len);
    result.extend_from_slice(&nonce_bytes);
    result.extend_from_slice(plaintext);
    let tag = sealing_key
        .seal_in_place_separate_tag(
            aead::Nonce::assume_unique_for_key(nonce_bytes),
            aead::Aad::empty(),
            &mut result[NONCE_LEN..],
        )
        .map_err(|_| BottleError::Encryption("Encryption failed".to_string()))?;
    result.extend_from_slice(tag.as_ref());
    Ok(result)
}
/// Opens an AES-256-GCM payload laid out as `nonce (12 bytes) || ciphertext || tag`, with
/// an empty AAD.
///
/// # Errors
/// - `BottleError::InvalidFormat` if `ciphertext` is shorter than the nonce.
/// - `BottleError::Decryption` if key setup fails, the payload is shorter than a tag, or
///   authentication fails.
fn decrypt_aes_gcm(key: &[u8; 32], ciphertext: &[u8]) -> Result<Vec<u8>> {
    use ring::aead::{self, LessSafeKey, UnboundKey, NONCE_LEN};

    if ciphertext.len() < NONCE_LEN {
        return Err(BottleError::InvalidFormat);
    }
    let (nonce_bytes, sealed) = ciphertext.split_at(NONCE_LEN);
    let nonce = aead::Nonce::try_assume_unique_for_key(nonce_bytes)
        .map_err(|_| BottleError::Decryption("Invalid nonce length".to_string()))?;

    let unbound_key = UnboundKey::new(&aead::AES_256_GCM, key)
        .map_err(|_| BottleError::Decryption("Key creation failed".to_string()))?;
    let opening_key = LessSafeKey::new(unbound_key);

    // Decrypt in place, then drop the trailing tag bytes instead of copying the plaintext
    // into a second Vec.
    let mut in_out = sealed.to_vec();
    let plaintext_len = opening_key
        .open_in_place(nonce, aead::Aad::empty(), &mut in_out)
        .map_err(|_| BottleError::Decryption("Decryption failed".to_string()))?
        .len();
    in_out.truncate(plaintext_len);
    Ok(in_out)
}
/// Encrypts `plaintext` to an RSA `public_key` using OAEP padding with SHA-256.
///
/// # Errors
/// `BottleError::Encryption` wrapping the `rsa` crate's error message (e.g. a message too
/// long for the modulus).
pub fn rsa_encrypt<R: RngCore + CryptoRng>(
    rng: &mut R,
    plaintext: &[u8],
    public_key: &rsa::RsaPublicKey,
) -> Result<Vec<u8>> {
    use rsa::Oaep;
    use sha2::Sha256;

    let encrypted = public_key.encrypt(rng, Oaep::new::<Sha256>(), plaintext);
    encrypted.map_err(|err| BottleError::Encryption(format!("RSA encryption failed: {}", err)))
}
/// Decrypts `ciphertext` with `rsa_key`, delegating entirely to `RsaKey::decrypt`.
///
/// # Errors
/// Whatever `RsaKey::decrypt` returns.
pub fn rsa_decrypt(ciphertext: &[u8], rsa_key: &crate::keys::RsaKey) -> Result<Vec<u8>> {
    let plaintext = rsa_key.decrypt(ciphertext)?;
    Ok(plaintext)
}