use super::keys::{PrivateKey, PublicKey};
use anyhow::{Result, anyhow};
use chacha20poly1305::{
XChaCha20Poly1305, XNonce,
aead::{Aead, KeyInit},
};
use rand::RngCore;
use x25519_dalek::{PublicKey as X25519Public, StaticSecret as X25519Secret};
/// Size in bytes of an XChaCha20-Poly1305 (extended) nonce.
const XNONCE_SIZE: usize = 24;
/// Size in bytes of the symmetric payload key and of every key-encryption key.
const KEY_SIZE: usize = 32;
/// Size in bytes of the X25519 ephemeral public key stored in each recipient slot.
const EPHEMERAL_PUBLIC_SIZE: usize = 32;
/// Size in bytes of the Poly1305 authentication tag appended to every AEAD ciphertext.
const TAG_SIZE: usize = 16;
/// Size in bytes of a wrapped payload key: the encrypted key followed by its tag.
const WRAPPED_KEY_SIZE: usize = KEY_SIZE + TAG_SIZE;
pub fn encrypt_asymmetric(plaintext: &[u8], recipients: &[PublicKey]) -> Result<Vec<u8>> {
if recipients.is_empty() {
return Err(anyhow!("At least one recipient is required"));
}
let mut symmetric_key = [0u8; KEY_SIZE];
rand::thread_rng().fill_bytes(&mut symmetric_key);
let mut payload_nonce = [0u8; XNONCE_SIZE];
rand::thread_rng().fill_bytes(&mut payload_nonce);
let cipher = XChaCha20Poly1305::new_from_slice(&symmetric_key)
.map_err(|e| anyhow!("Cipher creation failed: {}", e))?;
let ciphertext = cipher
.encrypt(XNonce::from_slice(&payload_nonce), plaintext)
.map_err(|e| anyhow!("Encryption failed: {}", e))?;
let mut output = Vec::new();
output.push(recipients.len() as u8);
for recipient in recipients {
let ephemeral_secret = X25519Secret::random_from_rng(rand::thread_rng());
let ephemeral_public = X25519Public::from(&ephemeral_secret);
let shared_secret = ephemeral_secret.diffie_hellman(&recipient.x25519);
let key_encryption_key = derive_key_encryption_key(shared_secret.as_bytes());
let mut key_nonce = [0u8; XNONCE_SIZE];
rand::thread_rng().fill_bytes(&mut key_nonce);
let key_cipher = XChaCha20Poly1305::new_from_slice(&key_encryption_key)
.map_err(|e| anyhow!("Key cipher creation failed: {}", e))?;
let wrapped_key = key_cipher
.encrypt(XNonce::from_slice(&key_nonce), symmetric_key.as_slice())
.map_err(|e| anyhow!("Key wrapping failed: {}", e))?;
output.extend_from_slice(ephemeral_public.as_bytes());
output.extend_from_slice(&key_nonce);
output.extend_from_slice(&wrapped_key);
}
output.extend_from_slice(&payload_nonce);
output.extend_from_slice(&ciphertext);
Ok(output)
}
/// Decrypts a ciphertext produced by [`encrypt_asymmetric`] using `private_key`.
///
/// Recipient slots are tried in order. Each slot's ephemeral X25519 public key is combined with
/// `private_key` to re-derive that slot's key-encryption key. The first slot whose wrapped key
/// authenticates and is exactly `KEY_SIZE` bytes long supplies the payload key, which then
/// decrypts the payload.
///
/// # Errors
///
/// Returns an error in these cases:
/// - `data` is empty or declares zero recipients.
/// - `data` is shorter than its header requires.
/// - No slot can be unwrapped with `private_key`.
/// - The payload fails authentication.
pub fn decrypt_asymmetric(data: &[u8], private_key: &PrivateKey) -> Result<Vec<u8>> {
    let (&count_byte, after_count) = data
        .split_first()
        .ok_or_else(|| anyhow!("Empty ciphertext"))?;
    let slot_count = usize::from(count_byte);
    if slot_count == 0 {
        return Err(anyhow!("No recipients in ciphertext"));
    }

    let slot_size = EPHEMERAL_PUBLIC_SIZE + XNONCE_SIZE + WRAPPED_KEY_SIZE;
    let slots_size = slot_count * slot_size;
    // After the slots there must be at least a payload nonce plus a 16-byte Poly1305 tag.
    if after_count.len() < slots_size + XNONCE_SIZE + 16 {
        return Err(anyhow!("Ciphertext too short"));
    }
    let (slots, payload) = after_count.split_at(slots_size);

    // Attempts to recover the payload key from one slot; Ok(None) means "not addressed to us".
    let open_slot = |slot: &[u8]| -> Result<Option<[u8; KEY_SIZE]>> {
        let (ephemeral_bytes, rest) = slot.split_at(EPHEMERAL_PUBLIC_SIZE);
        let (nonce_bytes, wrapped) = rest.split_at(XNONCE_SIZE);
        let ephemeral: [u8; 32] = ephemeral_bytes
            .try_into()
            .expect("slot prefix is exactly EPHEMERAL_PUBLIC_SIZE bytes");
        let shared_secret = private_key
            .x25519
            .diffie_hellman(&X25519Public::from(ephemeral));
        let kek = derive_key_encryption_key(shared_secret.as_bytes());
        let unwrapper = XChaCha20Poly1305::new_from_slice(&kek)
            .map_err(|e| anyhow!("Key cipher creation failed: {}", e))?;
        Ok(unwrapper
            .decrypt(XNonce::from_slice(nonce_bytes), wrapped)
            .ok()
            .and_then(|key| <[u8; KEY_SIZE]>::try_from(key.as_slice()).ok()))
    };

    // Stop at the first slot that either errors or yields a key, in slot order.
    let symmetric_key = slots
        .chunks_exact(slot_size)
        .map(open_slot)
        .find_map(Result::transpose)
        .transpose()?
        .ok_or_else(|| anyhow!("Could not decrypt: you may not be a recipient"))?;

    let (payload_nonce, ciphertext) = payload.split_at(XNONCE_SIZE);
    let cipher = XChaCha20Poly1305::new_from_slice(&symmetric_key)
        .map_err(|e| anyhow!("Cipher creation failed: {}", e))?;
    cipher
        .decrypt(XNonce::from_slice(payload_nonce), ciphertext)
        .map_err(|_| anyhow!("Payload decryption failed: corrupted data"))
}
/// Reads the recipient count from the first header byte of an [`encrypt_asymmetric`]
/// ciphertext, without decrypting anything.
///
/// Returns `None` for an empty buffer. The byte is returned as-is; it is not checked against
/// the length of the rest of `data`.
pub fn recipient_count(data: &[u8]) -> Option<u8> {
    match data {
        [count, ..] => Some(*count),
        [] => None,
    }
}
fn derive_key_encryption_key(shared_secret: &[u8]) -> [u8; KEY_SIZE] {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut result = [0u8; KEY_SIZE];
for i in 0..4 {
let mut hasher = DefaultHasher::new();
b"zimhide-key-derivation".hash(&mut hasher);
i.hash(&mut hasher);
shared_secret.hash(&mut hasher);
let hash = hasher.finish();
result[i * 8..(i + 1) * 8].copy_from_slice(&hash.to_le_bytes());
}
result
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::keys::Keypair;

    #[test]
    fn test_asymmetric_single_recipient() {
        let keypair = Keypair::generate();
        let plaintext = b"Secret message for one recipient";
        let encrypted =
            encrypt_asymmetric(plaintext, std::slice::from_ref(&keypair.public)).unwrap();
        let decrypted = decrypt_asymmetric(&encrypted, &keypair.private).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_asymmetric_multiple_recipients() {
        let keypair1 = Keypair::generate();
        let keypair2 = Keypair::generate();
        let plaintext = b"Secret message for multiple recipients";
        let recipients = [keypair1.public.clone(), keypair2.public.clone()];
        let encrypted = encrypt_asymmetric(plaintext, &recipients).unwrap();
        let decrypted1 = decrypt_asymmetric(&encrypted, &keypair1.private).unwrap();
        let decrypted2 = decrypt_asymmetric(&encrypted, &keypair2.private).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted1.as_slice());
        assert_eq!(plaintext.as_slice(), decrypted2.as_slice());
    }

    #[test]
    fn test_non_recipient_cannot_decrypt() {
        let recipient = Keypair::generate();
        let non_recipient = Keypair::generate();
        let plaintext = b"Secret message";
        let encrypted =
            encrypt_asymmetric(plaintext, std::slice::from_ref(&recipient.public)).unwrap();
        let result = decrypt_asymmetric(&encrypted, &non_recipient.private);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_plaintext_round_trips() {
        let keypair = Keypair::generate();
        let encrypted = encrypt_asymmetric(b"", std::slice::from_ref(&keypair.public)).unwrap();
        let decrypted = decrypt_asymmetric(&encrypted, &keypair.private).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_encrypt_requires_a_recipient() {
        assert!(encrypt_asymmetric(b"data", &[]).is_err());
    }

    #[test]
    fn test_recipient_count_matches_header() {
        let recipients = [
            Keypair::generate().public.clone(),
            Keypair::generate().public.clone(),
            Keypair::generate().public.clone(),
        ];
        let encrypted = encrypt_asymmetric(b"x", &recipients).unwrap();
        assert_eq!(recipient_count(&encrypted), Some(3));
        assert_eq!(recipient_count(&[]), None);
    }

    #[test]
    fn test_malformed_ciphertexts_are_rejected() {
        let keypair = Keypair::generate();
        assert!(decrypt_asymmetric(&[], &keypair.private).is_err());
        assert!(decrypt_asymmetric(&[0], &keypair.private).is_err());
        // Header claims one recipient, but the slot and payload are missing.
        assert!(decrypt_asymmetric(&[1, 0, 0], &keypair.private).is_err());
    }

    #[test]
    fn test_truncated_ciphertext_is_rejected() {
        let keypair = Keypair::generate();
        let encrypted =
            encrypt_asymmetric(b"some payload", std::slice::from_ref(&keypair.public)).unwrap();
        for len in [1, encrypted.len() / 2, encrypted.len() - 1] {
            assert!(decrypt_asymmetric(&encrypted[..len], &keypair.private).is_err());
        }
    }

    #[test]
    fn test_tampered_ciphertext_is_rejected() {
        let keypair = Keypair::generate();
        let encrypted =
            encrypt_asymmetric(b"integrity matters", std::slice::from_ref(&keypair.public))
                .unwrap();

        // A flipped bit in the payload must fail the payload tag check.
        let mut payload_tampered = encrypted.clone();
        *payload_tampered.last_mut().unwrap() ^= 0x01;
        assert!(decrypt_asymmetric(&payload_tampered, &keypair.private).is_err());

        // A flipped bit in the wrapped key must fail the key-unwrap tag check.
        let mut slot_tampered = encrypted.clone();
        slot_tampered[1 + EPHEMERAL_PUBLIC_SIZE + XNONCE_SIZE] ^= 0x01;
        assert!(decrypt_asymmetric(&slot_tampered, &keypair.private).is_err());
    }

    #[test]
    fn test_key_derivation_is_deterministic() {
        let key = derive_key_encryption_key(&[7u8; 32]);
        assert_eq!(key, derive_key_encryption_key(&[7u8; 32]));
        assert_ne!(key, derive_key_encryption_key(&[8u8; 32]));
    }
}