use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
use hkdf::Hkdf;
use rand::RngCore;
use sha2::Sha256;
use x25519_dalek::{PublicKey, StaticSecret};
use zeroize::Zeroize;
use std::fmt;
/// Failure modes reported by the key-decoding and decryption helpers in this file.
#[derive(Debug)]
pub enum CryptoError {
    /// The input text could not be decoded as base64. `decrypt_group_key` also
    /// reports this when its `nonce:ciphertext` separator is missing.
    InvalidBase64,
    /// The decoded bytes were not exactly 32 bytes long.
    InvalidKeyLength,
    /// Authenticated decryption did not succeed.
    DecryptionFailed,
}

impl fmt::Display for CryptoError {
    /// Writes a short, fixed, lowercase description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::InvalidBase64 => "invalid base64 encoding",
            Self::InvalidKeyLength => "invalid key length",
            Self::DecryptionFailed => "decryption failed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for CryptoError {}
/// An X25519 keypair: a static secret together with the public key computed from it.
pub struct Keypair {
    secret: StaticSecret,
    public: PublicKey,
}

impl Keypair {
    /// Creates a new keypair from 32 bytes drawn from `rand::rng()`.
    ///
    /// The temporary seed buffer is wiped as soon as the secret has been built.
    pub fn generate() -> Self {
        let mut seed = [0u8; 32];
        rand::rng().fill_bytes(&mut seed);

        // `[u8; 32]` is `Copy`, so the secret is built from a copy and our local
        // buffer still holds the key material until it is wiped here.
        let secret = StaticSecret::from(seed);
        seed.zeroize();

        Self {
            public: PublicKey::from(&secret),
            secret,
        }
    }

    /// Returns a reference to the public half of this keypair.
    pub fn public_key(&self) -> &PublicKey {
        &self.public
    }

    /// Derives a 32-byte pairwise key shared with the owner of `peer_public`.
    ///
    /// Runs Diffie-Hellman between our secret and `peer_public`, then feeds the
    /// result through HKDF-SHA256 with salt `"opencord-v1"` and info
    /// `"pairwise-key"`. Both peers obtain the same key when each passes the
    /// other's public key.
    ///
    /// NOTE(review): the DH output is used without any visible check for a
    /// degenerate (e.g. low-order) peer key — confirm whether callers need one.
    ///
    /// # Panics
    ///
    /// Only if HKDF refuses a 32-byte output, which is treated as unreachable.
    pub fn derive_shared_key(&self, peer_public: &PublicKey) -> [u8; 32] {
        let dh_output = self.secret.diffie_hellman(peer_public);
        let kdf = Hkdf::<Sha256>::new(Some(b"opencord-v1"), dh_output.as_bytes());

        let mut okm = [0u8; 32];
        kdf.expand(b"pairwise-key", &mut okm)
            .expect("32 bytes is a valid HKDF output length");
        okm
    }
}
/// Encodes the raw bytes of `key` as base64 text using the `STANDARD` engine.
pub fn encode_public_key(key: &PublicKey) -> String {
    let raw = key.as_bytes();
    BASE64.encode(raw)
}
/// Parses a base64-encoded public key.
///
/// # Errors
///
/// * [`CryptoError::InvalidBase64`] if `encoded` is not valid base64.
/// * [`CryptoError::InvalidKeyLength`] if it decodes to anything other than
///   exactly 32 bytes.
pub fn decode_public_key(encoded: &str) -> Result<PublicKey, CryptoError> {
    BASE64
        .decode(encoded)
        .map_err(|_| CryptoError::InvalidBase64)
        .and_then(|raw| <[u8; 32]>::try_from(raw).map_err(|_| CryptoError::InvalidKeyLength))
        .map(PublicKey::from)
}
/// Returns 32 fresh bytes from `rand::rng()` for use as a group key.
pub fn generate_group_key() -> [u8; 32] {
    let mut material = [0u8; 32];
    let mut rng = rand::rng();
    rng.fill_bytes(&mut material);
    material
}
/// Encrypts `plaintext` with AES-256-GCM under `key`, using a freshly drawn
/// random 12-byte nonce.
///
/// Returns `(nonce, ciphertext)`, both base64-encoded, in that order.
///
/// # Panics
///
/// If the cipher rejects the key or the encryption call fails; both are
/// treated as unreachable for a 32-byte key and 12-byte nonce.
pub fn encrypt(key: &[u8; 32], plaintext: &[u8]) -> (String, String) {
    let cipher = Aes256Gcm::new_from_slice(key).expect("32-byte key is valid for AES-256");

    let mut nonce_bytes = [0u8; 12];
    rand::rng().fill_bytes(&mut nonce_bytes);

    let sealed = cipher
        .encrypt(Nonce::from_slice(&nonce_bytes), plaintext)
        .expect("encryption should not fail with valid key and nonce");

    let nonce_b64 = BASE64.encode(nonce_bytes);
    let ciphertext_b64 = BASE64.encode(sealed);
    (nonce_b64, ciphertext_b64)
}
/// Decrypts a base64 `(nonce, ciphertext)` pair produced by [`encrypt`].
///
/// # Errors
///
/// * [`CryptoError::InvalidBase64`] if either input is not valid base64.
/// * [`CryptoError::DecryptionFailed`] if the decoded nonce is not exactly
///   12 bytes, or if AES-256-GCM rejects the ciphertext (wrong key, tampered
///   or truncated data).
///
/// # Panics
///
/// Only if the cipher rejects the 32-byte key, which is treated as unreachable.
pub fn decrypt(
    key: &[u8; 32],
    nonce_b64: &str,
    ciphertext_b64: &str,
) -> Result<Vec<u8>, CryptoError> {
    let cipher = Aes256Gcm::new_from_slice(key).expect("32-byte key is valid for AES-256");

    let nonce_bytes = BASE64
        .decode(nonce_b64)
        .map_err(|_| CryptoError::InvalidBase64)?;
    // `Nonce::from_slice` panics when the slice length does not match the
    // nonce size. The nonce is caller-supplied, so enforce the 12-byte size
    // here and report a normal error instead of crashing.
    let nonce_bytes: [u8; 12] = nonce_bytes
        .try_into()
        .map_err(|_| CryptoError::DecryptionFailed)?;
    let nonce = Nonce::from_slice(&nonce_bytes);

    let ciphertext = BASE64
        .decode(ciphertext_b64)
        .map_err(|_| CryptoError::InvalidBase64)?;

    cipher
        .decrypt(nonce, ciphertext.as_slice())
        .map_err(|_| CryptoError::DecryptionFailed)
}
/// Encrypts `group_key` under `pairwise_key` and packs the result into a
/// single `"<nonce>:<ciphertext>"` string, both parts base64-encoded.
pub fn encrypt_group_key(pairwise_key: &[u8; 32], group_key: &[u8; 32]) -> String {
    let (nonce, wrapped) = encrypt(pairwise_key, group_key);

    let mut envelope = String::with_capacity(nonce.len() + 1 + wrapped.len());
    envelope.push_str(&nonce);
    envelope.push(':');
    envelope.push_str(&wrapped);
    envelope
}
/// Unpacks and decrypts a group key produced by [`encrypt_group_key`].
///
/// `encrypted` must have the form `"<nonce>:<ciphertext>"`; it is split at the
/// first `':'`. The intermediate plaintext buffer is wiped before returning,
/// on the success path and on the wrong-length path alike.
///
/// # Errors
///
/// * [`CryptoError::InvalidBase64`] if the `':'` separator is missing.
/// * Any error returned by [`decrypt`] for the two parts.
/// * [`CryptoError::InvalidKeyLength`] if the decrypted payload is not exactly
///   32 bytes.
pub fn decrypt_group_key(
    pairwise_key: &[u8; 32],
    encrypted: &str,
) -> Result<[u8; 32], CryptoError> {
    let (nonce_b64, ciphertext_b64) = encrypted
        .split_once(':')
        .ok_or(CryptoError::InvalidBase64)?;

    let mut plaintext = decrypt(pairwise_key, nonce_b64, ciphertext_b64)?;

    // Hold the conversion result instead of using `?` so the decrypted bytes
    // are wiped even when the length check fails.
    let key = <[u8; 32]>::try_from(plaintext.as_slice())
        .map_err(|_| CryptoError::InvalidKeyLength);
    plaintext.zeroize();
    key
}
#[cfg(unix)]
pub fn mlock_key(key: &[u8; 32]) -> bool {
unsafe { libc::mlock(key.as_ptr().cast(), 32) == 0 }
}
/// Fallback for non-Unix targets: performs no locking and always returns
/// `false`.
#[cfg(not(unix))]
pub fn mlock_key(_key: &[u8; 32]) -> bool {
false
}
#[cfg(unix)]
pub fn munlock_key(key: &[u8; 32]) {
unsafe {
libc::munlock(key.as_ptr().cast(), 32);
}
}
/// Fallback for non-Unix targets: does nothing.
#[cfg(not(unix))]
pub fn munlock_key(_key: &[u8; 32]) {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both sides of a key exchange must derive the same pairwise key.
    #[test]
    fn dh_symmetry() {
        let (alice, bob) = (Keypair::generate(), Keypair::generate());

        let from_alice = alice.derive_shared_key(bob.public_key());
        let from_bob = bob.derive_shared_key(alice.public_key());

        assert_eq!(from_alice, from_bob, "DH should be symmetric");
    }

    /// Decrypting with the encrypting key returns the original message.
    #[test]
    fn encrypt_decrypt_roundtrip() {
        let key = generate_group_key();
        let message = b"hello opencord!";

        let (nonce_b64, ciphertext_b64) = encrypt(&key, message);
        let recovered = decrypt(&key, &nonce_b64, &ciphertext_b64).unwrap();

        assert_eq!(recovered, message);
    }

    /// Decrypting with a different key fails with `DecryptionFailed`.
    #[test]
    fn wrong_key_rejected() {
        let encrypting_key = generate_group_key();
        let other_key = generate_group_key();

        let (nonce_b64, ciphertext_b64) = encrypt(&encrypting_key, b"secret stuff");
        let outcome = decrypt(&other_key, &nonce_b64, &ciphertext_b64);

        assert!(outcome.is_err(), "decryption with wrong key should fail");
        assert!(
            matches!(outcome.unwrap_err(), CryptoError::DecryptionFailed),
            "should be DecryptionFailed"
        );
    }

    /// A group key wrapped by one peer can be unwrapped by the other.
    #[test]
    fn group_key_roundtrip() {
        let alice = Keypair::generate();
        let bob = Keypair::generate();

        // Alice wraps the group key under her view of the pairwise key.
        let alice_pairwise = alice.derive_shared_key(bob.public_key());
        let group_key = generate_group_key();
        let envelope = encrypt_group_key(&alice_pairwise, &group_key);

        // Bob unwraps it using his independently derived pairwise key.
        let bob_pairwise = bob.derive_shared_key(alice.public_key());
        let unwrapped = decrypt_group_key(&bob_pairwise, &envelope).unwrap();

        assert_eq!(unwrapped, group_key);
    }

    /// Encoding then decoding a public key yields the same bytes.
    #[test]
    fn public_key_encode_decode_roundtrip() {
        let keypair = Keypair::generate();

        let text = encode_public_key(keypair.public_key());
        let parsed = decode_public_key(&text).unwrap();

        assert_eq!(parsed.as_bytes(), keypair.public_key().as_bytes());
    }

    /// Text that is not base64 is rejected as `InvalidBase64`.
    #[test]
    fn invalid_base64_rejected() {
        assert!(matches!(
            decode_public_key("not-valid-base64!!!"),
            Err(CryptoError::InvalidBase64)
        ));
    }

    /// Valid base64 of the wrong size is rejected as `InvalidKeyLength`.
    #[test]
    fn wrong_length_key_rejected() {
        let too_short = BASE64.encode([0u8; 16]);
        assert!(matches!(
            decode_public_key(&too_short),
            Err(CryptoError::InvalidKeyLength)
        ));
    }
}