use crypto_box::{
PublicKey, SalsaBox, SecretKey,
aead::{Aead, AeadCore, OsRng},
};
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt;
use thiserror::Error;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::boxed_message::{BoxedMessage, is_boxed_message};
/// Length in bytes of every raw key (public or private) handled by this module.
pub const KEY_SIZE: usize = 32;
/// Raw key material: a fixed-size array of [`KEY_SIZE`] bytes.
pub type KeyBytes = [u8; KEY_SIZE];
/// Errors returned by the key, encryption and decryption helpers in this module.
///
/// Variants carry no payload: where an underlying library error exists, it is
/// discarded via `map_err(|_| ...)` before one of these variants is returned.
#[derive(Error, Debug)]
pub enum CryptoError {
    /// Authenticated decryption of a parsed message failed (e.g. wrong key or
    /// altered ciphertext).
    #[error("couldn't decrypt message")]
    DecryptionFailed,
    /// The underlying cipher rejected an encryption request.
    #[error("encryption failed")]
    EncryptionFailed,
    /// Random byte generation failed.
    /// NOTE(review): not constructed by any code visible here — confirm it is still needed.
    #[error("failed to generate random bytes")]
    RandomGenerationFailed,
    /// A key had the wrong number of bytes.
    /// NOTE(review): not constructed by any code visible here — confirm it is still needed.
    #[error("invalid key length")]
    InvalidKeyLength,
    /// The input bytes could not be parsed as a `BoxedMessage`.
    #[error("invalid message format")]
    InvalidMessageFormat,
}
/// A public/private key pair stored as raw bytes for use with `crypto_box`.
///
/// Dropping a `Keypair` zeroizes `private` (derived `ZeroizeOnDrop`); `public`
/// is excluded from zeroization by `#[zeroize(skip)]`. The type is not `Clone`;
/// use [`Keypair::encrypter`] / [`Keypair::decrypter`] to build helpers from a
/// borrowed keypair.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct Keypair {
    /// Public key bytes; left intact on drop.
    #[zeroize(skip)] pub public: KeyBytes,
    /// Private key bytes; overwritten with zeros on drop.
    pub private: KeyBytes,
}
impl fmt::Debug for Keypair {
    /// Prints the public key as hex and masks the private key so it never
    /// reaches logs or panic messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let public_hex = hex::encode(self.public);
        let mut out = f.debug_struct("Keypair");
        out.field("public", &public_hex);
        out.field("private", &"[REDACTED]");
        out.finish()
    }
}
impl Keypair {
pub fn generate() -> Result<Self, CryptoError> {
let secret_key = SecretKey::generate(&mut OsRng);
let public_key = secret_key.public_key();
Ok(Self {
public: *public_key.as_bytes(),
private: secret_key.to_bytes(),
})
}
pub fn from_keys(public: KeyBytes, private: KeyBytes) -> Self {
Self { public, private }
}
pub fn public_string(&self) -> String {
hex::encode(self.public)
}
pub fn private_string(&self) -> String {
hex::encode(self.private)
}
pub fn into_encrypter(self, peer_public: KeyBytes) -> Encrypter {
Encrypter::new(self, peer_public)
}
pub fn into_decrypter(self) -> Decrypter {
Decrypter::new(self)
}
pub fn encrypter(&self, peer_public: KeyBytes) -> Encrypter {
let kp = Self::from_keys(self.public, self.private);
Encrypter::new(kp, peer_public)
}
pub fn decrypter(&self) -> Decrypter {
let kp = Self::from_keys(self.public, self.private);
Decrypter::new(kp)
}
}
/// Encrypts messages from a local keypair to one fixed peer public key.
///
/// The `SalsaBox` for this (local private key, peer public key) pair is built
/// once in [`Encrypter::new`] and reused for every message. On drop, the
/// derived `ZeroizeOnDrop` zeroizes `keypair`; `peer_public` and `salsa_box`
/// are skipped.
/// NOTE(review): whether `SalsaBox` wipes its own key material on drop is up to
/// the `crypto_box` crate — confirm.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct Encrypter {
    /// Local (sender) keypair; its public key is embedded in every message.
    keypair: Keypair,
    /// Recipient's public key.
    #[zeroize(skip)]
    peer_public: KeyBytes,
    /// Precomputed box used for every `encrypt` call.
    #[zeroize(skip)] salsa_box: SalsaBox,
}
impl fmt::Debug for Encrypter {
    /// Shows the peer's public key as hex; the local keypair is always masked.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Encrypter");
        out.field("keypair", &"[REDACTED]");
        out.field("peer_public", &hex::encode(self.peer_public));
        out.finish()
    }
}
impl Encrypter {
    /// Creates an encrypter and precomputes the `SalsaBox` for
    /// (`keypair.private`, `peer_public`).
    pub fn new(keypair: Keypair, peer_public: KeyBytes) -> Self {
        let salsa_box = SalsaBox::new(
            &PublicKey::from(peer_public),
            &SecretKey::from(keypair.private),
        );
        Self { keypair, peer_public, salsa_box }
    }

    /// Encrypts `message` and returns the serialized `BoxedMessage` bytes.
    ///
    /// If `message` already satisfies `is_boxed_message`, it is returned
    /// unchanged (copied) instead of being encrypted again.
    /// NOTE(review): plaintext that merely looks like a boxed message is
    /// therefore emitted unencrypted — confirm callers never pass untrusted
    /// plaintext here.
    ///
    /// # Errors
    /// [`CryptoError::EncryptionFailed`] if the cipher rejects the input.
    pub fn encrypt(&self, message: &[u8]) -> Result<Vec<u8>, CryptoError> {
        if !is_boxed_message(message) {
            return self.encrypt_raw(message).map(|boxed| boxed.dump());
        }
        Ok(message.to_vec())
    }

    /// Encrypts `message` under a fresh random nonce from `OsRng` and wraps
    /// the result in a schema-version-1 `BoxedMessage` carrying our public key.
    fn encrypt_raw(&self, message: &[u8]) -> Result<BoxedMessage, CryptoError> {
        let nonce = SalsaBox::generate_nonce(&mut OsRng);
        match self.salsa_box.encrypt(&nonce, message) {
            Ok(box_data) => Ok(BoxedMessage {
                schema_version: 1,
                encrypter_public: self.keypair.public,
                nonce: nonce.into(),
                box_data,
            }),
            Err(_) => Err(CryptoError::EncryptionFailed),
        }
    }
}
/// Decrypts boxed messages addressed to a local keypair, from any sender.
///
/// A `SalsaBox` per sender public key is kept in `cache` so the key agreement
/// is not redone for every message from the same sender. Because `cache` is a
/// `RefCell`, a `Decrypter` is not `Sync` and cannot be shared between threads
/// by reference. No eviction or size limit is applied to the cache in the code
/// visible here.
pub struct Decrypter {
    /// Local (recipient) keypair.
    keypair: Keypair,
    /// Precomputed boxes keyed by the sender's public key bytes.
    cache: RefCell<HashMap<KeyBytes, SalsaBox>>,
}
impl Drop for Decrypter {
fn drop(&mut self) {
self.keypair.private.zeroize();
}
}
impl fmt::Debug for Decrypter {
    /// Masks the keypair and reports only how many sender boxes are cached.
    ///
    /// Borrows `cache` immutably, so this panics if called while a mutable
    /// borrow of the cache is outstanding.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Decrypter");
        out.field("keypair", &"[REDACTED]");
        out.field("cached_keys", &self.cache.borrow().len());
        out.finish()
    }
}
impl Decrypter {
    /// Creates a decrypter for messages addressed to `keypair`, starting with
    /// an empty per-sender cache.
    pub fn new(keypair: Keypair) -> Self {
        Self {
            keypair,
            cache: RefCell::new(HashMap::new()),
        }
    }

    /// Parses `message` as a serialized `BoxedMessage` and decrypts it.
    ///
    /// # Errors
    /// - [`CryptoError::InvalidMessageFormat`] if `BoxedMessage::load` rejects the bytes.
    /// - [`CryptoError::DecryptionFailed`] if authenticated decryption fails.
    pub fn decrypt(&self, message: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let boxed = BoxedMessage::load(message).map_err(|_| CryptoError::InvalidMessageFormat)?;
        self.decrypt_boxed(&boxed)
    }

    /// Decrypts an already-parsed message with the box for its sender key.
    ///
    /// A sender's box is added to the cache only after a message from that
    /// sender has decrypted successfully. Inputs that fail authentication
    /// (e.g. junk carrying random `encrypter_public` values) therefore cannot
    /// grow the cache; previously every such input inserted a new entry.
    ///
    /// # Errors
    /// [`CryptoError::DecryptionFailed`] if authenticated decryption fails.
    fn decrypt_boxed(&self, boxed: &BoxedMessage) -> Result<Vec<u8>, CryptoError> {
        let nonce = boxed.nonce.into();
        let ciphertext = boxed.box_data.as_slice();

        // Fast path: a box for this sender was cached after an earlier success.
        // The shared borrow ends with this `if let`, before any mutation below.
        if let Some(salsa_box) = self.cache.borrow().get(&boxed.encrypter_public) {
            return salsa_box
                .decrypt(&nonce, ciphertext)
                .map_err(|_| CryptoError::DecryptionFailed);
        }

        let secret_key = SecretKey::from(self.keypair.private);
        let peer_public = PublicKey::from(boxed.encrypter_public);
        let salsa_box = SalsaBox::new(&peer_public, &secret_key);
        let plaintext = salsa_box
            .decrypt(&nonce, ciphertext)
            .map_err(|_| CryptoError::DecryptionFailed)?;

        // Cache only once the sender key has produced an authentic message.
        self.cache
            .borrow_mut()
            .insert(boxed.encrypter_public, salsa_box);
        Ok(plaintext)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_keypair_generation() {
        let kp = Keypair::generate().unwrap();
        assert_eq!(kp.public.len(), 32);
        assert_eq!(kp.private.len(), 32);
        assert_eq!(kp.public_string().len(), 64);
        assert_eq!(kp.private_string().len(), 64);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let sender_kp = Keypair::generate().unwrap();
        let receiver_kp = Keypair::generate().unwrap();
        let receiver_public = receiver_kp.public;
        let encrypter = sender_kp.into_encrypter(receiver_public);
        let decrypter = receiver_kp.into_decrypter();
        let plaintext = b"Hello, World!";
        let encrypted = encrypter.encrypt(plaintext).unwrap();
        let decrypted = decrypter.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_already_encrypted_passthrough() {
        let kp = Keypair::generate().unwrap();
        let encrypter = kp.encrypter(kp.public);
        let plaintext = b"secret";
        let encrypted = encrypter.encrypt(plaintext).unwrap();
        let double_encrypted = encrypter.encrypt(&encrypted).unwrap();
        assert_eq!(encrypted, double_encrypted);
    }

    #[test]
    fn test_keypair_debug_redacts_private_key() {
        let kp = Keypair::generate().unwrap();
        let debug_output = format!("{:?}", kp);
        assert!(debug_output.contains("[REDACTED]"));
        assert!(!debug_output.contains(&kp.private_string()));
    }

    // A message addressed to one recipient must not be readable with another
    // recipient's key; the failure must surface as `DecryptionFailed`.
    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let sender_kp = Keypair::generate().unwrap();
        let receiver_kp = Keypair::generate().unwrap();
        let intruder_kp = Keypair::generate().unwrap();
        let encrypter = sender_kp.into_encrypter(receiver_kp.public);
        let intruder = intruder_kp.into_decrypter();
        let encrypted = encrypter.encrypt(b"for the receiver only").unwrap();
        let result = intruder.decrypt(&encrypted);
        assert!(matches!(result, Err(CryptoError::DecryptionFailed)));
    }

    // The helper types must not leak the private key through `Debug` either.
    #[test]
    fn test_encrypter_and_decrypter_debug_redact_private_key() {
        let kp = Keypair::generate().unwrap();
        let private_hex = kp.private_string();
        let encrypter = kp.encrypter(kp.public);
        let decrypter = kp.decrypter();
        for output in [format!("{:?}", encrypter), format!("{:?}", decrypter)] {
            assert!(output.contains("[REDACTED]"));
            assert!(!output.contains(&private_hex));
        }
    }
}