use chacha20::cipher::generic_array::{typenum::Unsigned, GenericArray};
use chacha20poly1305::aead::Aead;
use crypto_box::aead::AeadCore;
use litl::{impl_debug_as_litl, impl_single_tagged_data_serde, SingleTaggedData, TaggedDataError};
use rand08::rngs::OsRng;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_derive::{Deserialize, Serialize};
use std::{array::TryFromSliceError, borrow::Cow, fmt::Debug, marker::PhantomData, ops::Deref};
use thiserror::Error;
use zeroize::Zeroizing;
/// Public identity of a message recipient.
///
/// Wraps the `crypto_box` public key that `RecipientID::encrypt_from_anon`
/// seals messages to.
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum RecipientID {
    /// Version-1 recipient: a plain `crypto_box` public key.
    RecipientV1(crypto_box::PublicKey),
}
impl RecipientID {
    /// The well-known "everyone" recipient.
    ///
    /// This is the public counterpart of `RecipientSecret::everyone()`, whose
    /// secret key is all zeros. Anyone can reconstruct that secret, so anything
    /// encrypted to this ID is readable by anyone.
    pub fn everyone() -> Self {
        let shared_secret = RecipientSecret::everyone();
        shared_secret.pub_id()
    }
}
impl SingleTaggedData for RecipientID {
const TAG: &'static str = "recipientID";
fn as_bytes(&self) -> Cow<[u8]> {
match self {
RecipientID::RecipientV1(key) => Cow::from(key.as_bytes().as_ref()),
}
}
fn from_bytes(data: &[u8]) -> Result<Self, TaggedDataError>
where
Self: Sized,
{
let key_bytes: [u8; crypto_box::KEY_SIZE] = data
.try_into()
.map_err(|err| TaggedDataError::data_error(Into::<RecipientIDError>::into(err)))?;
Ok(RecipientID::RecipientV1(crypto_box::PublicKey::from(
key_bytes,
)))
}
}
#[derive(Debug, Error)]
pub enum RecipientIDError {
#[error("Invalid RecipientID length")]
InvalidLength(#[from] TryFromSliceError),
}
// Serde support for `RecipientID`, presumably built on its `SingleTaggedData`
// impl (tag "recipientID" + raw key bytes) — NOTE(review): confirm against litl.
impl_single_tagged_data_serde!(RecipientID);
// Debug formatting for `RecipientID`; per the macro name, rendered as litl.
impl_debug_as_litl!(RecipientID);
/// Lets a `RecipientID` be used wherever a `&crypto_box::PublicKey` is expected
/// (e.g. when constructing a `ChaChaBox` or deriving a nonce).
impl Deref for RecipientID {
    type Target = crypto_box::PublicKey;

    fn deref(&self) -> &Self::Target {
        match self {
            Self::RecipientV1(key) => key,
        }
    }
}
impl Ord for RecipientID {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
match (self, other) {
(RecipientID::RecipientV1(pub_key1), RecipientID::RecipientV1(pub_key2)) => {
pub_key1.as_bytes().cmp(pub_key2.as_bytes())
}
}
}
}
/// Delegates to the total order defined by `Ord`.
impl PartialOrd for RecipientID {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// An encrypted payload of type `T`, sealed to `recipient` with a one-off
/// ephemeral sender key (`from_anon`), so no long-term sender identity is used.
///
/// `T` is only a type-level marker for what the ciphertext decrypts to; no `T`
/// value is stored. The fields are serialized in declaration order, so
/// reordering them changes the serialized layout.
#[derive(Serialize, Deserialize)]
pub struct EncrFromAnon<T> {
    /// Public half of the ephemeral keypair generated for this message.
    from_anon: RecipientID,
    /// The recipient this payload was encrypted to.
    pub recipient: RecipientID,
    /// `ChaChaBox` ciphertext of the litl-serialized `T`, (de)serialized as raw data.
    #[serde(with = "litl::raw_data_serde")]
    encrypted: Vec<u8>,
    /// Ties the envelope to its plaintext type; skipped during (de)serialization.
    #[serde(skip)]
    _marker: PhantomData<T>,
}
impl<T> Debug for EncrFromAnon<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&litl::to_string_pretty(&self).unwrap())
}
}
/// Hand-written so that cloning does not require `T: Clone`.
/// `#[derive(Clone)]` would add that bound because of the `PhantomData<T>` marker.
impl<T> Clone for EncrFromAnon<T> {
    fn clone(&self) -> Self {
        let Self {
            from_anon,
            recipient,
            encrypted,
            ..
        } = self;
        Self {
            from_anon: from_anon.clone(),
            recipient: recipient.clone(),
            encrypted: encrypted.clone(),
            _marker: PhantomData,
        }
    }
}
/// Structural equality over the sender key, recipient and ciphertext.
/// The `PhantomData` marker is ignored, so `T` needs no `PartialEq` bound.
impl<T> PartialEq for EncrFromAnon<T> {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (&self.from_anon, &self.recipient, &self.encrypted);
        let rhs = (&other.from_anon, &other.recipient, &other.encrypted);
        lhs == rhs
    }
}

impl<T> Eq for EncrFromAnon<T> {}
impl RecipientID {
pub fn encrypt_from_anon<'de, T: Serialize + Deserialize<'de>>(
&self,
data: &T,
) -> EncrFromAnon<T> {
let ephemeral_sender_secret = RecipientSecret::new_random();
let ephemeral_sender = ephemeral_sender_secret.pub_id();
let nonce = nonce_for(&ephemeral_sender, self);
let chacha_box = crypto_box::ChaChaBox::new(self, &ephemeral_sender_secret);
EncrFromAnon {
from_anon: ephemeral_sender,
recipient: self.clone(),
encrypted: chacha_box
.encrypt(&nonce, litl::to_vec(data).unwrap().as_slice())
.expect("Unable to encrypt"),
_marker: PhantomData,
}
}
}
#[derive(Error, Debug)]
pub enum AsymmDecryptionError {
#[error("Decryption error.")]
DecryptionError,
#[error("Error converting from decrypted bytes.")]
DeserializeError(litl::Error),
}
/// Nonce type expected by `crypto_box::ChaChaBox`.
type Nonce = GenericArray<u8, <crypto_box::ChaChaBox as AeadCore>::NonceSize>;

/// Derives the nonce for a message deterministically from the two public keys.
///
/// It takes the leading `NonceSize` bytes of BLAKE3(ephemeral ‖ recipient).
/// Both sides can recompute it, so the nonce does not have to be transmitted.
/// Uniqueness relies on the ephemeral key being fresh for every message.
fn nonce_for(
    ephemeral_pubkey: &crypto_box::PublicKey,
    recipient_pubkey: &crypto_box::PublicKey,
) -> Nonce {
    let nonce_len = <crypto_box::ChaChaBox as AeadCore>::NonceSize::to_usize();
    let digest = blake3::Hasher::new()
        .update(ephemeral_pubkey.as_bytes())
        .update(recipient_pubkey.as_bytes())
        .finalize();
    Nonce::clone_from_slice(&digest.as_bytes()[..nonce_len])
}
/// Secret key that lets its holder decrypt messages sent to the matching
/// `RecipientID` (see `RecipientSecret::pub_id`).
///
/// NOTE(review): no `Debug`/`Clone` is derived here, presumably so that secret
/// material is not printed or duplicated casually.
pub enum RecipientSecret {
    /// Version-1 secret: a plain `crypto_box` secret key.
    RecipientSecretV1(crypto_box::SecretKey),
}
impl SingleTaggedData for RecipientSecret {
const TAG: &'static str = "recipientSecretV1";
fn as_bytes(&self) -> Cow<[u8]> {
match self {
RecipientSecret::RecipientSecretV1(key) => Cow::from(key.as_bytes().as_ref()),
}
}
fn from_bytes(data: &[u8]) -> Result<Self, TaggedDataError>
where
Self: Sized,
{
let key_bytes: [u8; crypto_box::KEY_SIZE] = data
.try_into()
.map_err(|err| TaggedDataError::data_error(Into::<RecipientSecretError>::into(err)))?;
Ok(RecipientSecret::RecipientSecretV1(
crypto_box::SecretKey::from(key_bytes),
))
}
}
#[derive(Debug, Error)]
pub enum RecipientSecretError {
#[error("Invalid RecipientSecret length")]
InvalidLength(#[from] TryFromSliceError),
}
// Serde support for `RecipientSecret`, presumably built on its `SingleTaggedData`
// impl (tag "recipientSecretV1" + raw key bytes) — NOTE(review): confirm against litl.
impl_single_tagged_data_serde!(RecipientSecret);
/// Lets a `RecipientSecret` be used wherever a `&crypto_box::SecretKey` is
/// expected (e.g. when constructing a `ChaChaBox`).
impl Deref for RecipientSecret {
    type Target = crypto_box::SecretKey;

    fn deref(&self) -> &Self::Target {
        match self {
            Self::RecipientSecretV1(key) => key,
        }
    }
}
impl RecipientSecret {
    /// Generates a fresh secret key from the operating system's RNG.
    pub fn new_random() -> Self {
        let mut rng = OsRng {};
        Self::RecipientSecretV1(crypto_box::SecretKey::generate(&mut rng))
    }

    /// The publicly known all-zero secret backing `RecipientID::everyone()`.
    ///
    /// Anyone can construct this, so it provides no confidentiality.
    pub fn everyone() -> Self {
        let all_zero_key = [0u8; crypto_box::KEY_SIZE];
        Self::RecipientSecretV1(crypto_box::SecretKey::from(all_zero_key))
    }

    /// Returns the public `RecipientID` matching this secret.
    pub fn pub_id(&self) -> RecipientID {
        let public_key = match self {
            Self::RecipientSecretV1(secret_key) => secret_key.public_key(),
        };
        RecipientID::RecipientV1(public_key)
    }

    /// Opens an envelope produced by `RecipientID::encrypt_from_anon`.
    ///
    /// The decrypted bytes are held in a `Zeroizing` buffer, so they are wiped
    /// once deserialization into `T` completes.
    ///
    /// # Errors
    ///
    /// * `AsymmDecryptionError::DecryptionError` if AEAD decryption fails
    ///   (wrong secret, or corrupted ciphertext).
    /// * `AsymmDecryptionError::DeserializeError` if the plaintext is not valid
    ///   litl for `T`.
    pub fn decrypt<T: DeserializeOwned>(
        &self,
        encr: &EncrFromAnon<T>,
    ) -> Result<T, AsymmDecryptionError> {
        let own_id = self.pub_id();
        let nonce = nonce_for(&encr.from_anon, &own_id);
        let cipher = crypto_box::ChaChaBox::new(&encr.from_anon, self);
        let decrypted_bytes = cipher
            .decrypt(&nonce, encr.encrypted.as_slice())
            .map_err(|_| AsymmDecryptionError::DecryptionError)?;
        let plaintext = Zeroizing::new(decrypted_bytes);
        litl::from_slice_owned::<T>(&plaintext).map_err(AsymmDecryptionError::DeserializeError)
    }
}
#[cfg(test)]
mod test {
    use serde_derive::{Deserialize, Serialize};

    use crate::asymm_encr::{AsymmDecryptionError, RecipientID, RecipientSecret};

    /// Small serde-able payload used as the plaintext in these tests.
    #[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
    struct TestData {
        bla: [u8; 4],
    }

    #[test]
    fn encryption_roundtrip_works() {
        let data = TestData { bla: [1, 2, 3, 4] };
        let recipient = RecipientSecret::new_random();
        let encrypted = recipient.pub_id().encrypt_from_anon(&data);
        println!("{}", litl::to_string_pretty(&encrypted).unwrap());
        let decrypted = recipient.decrypt(&encrypted);
        assert_eq!(decrypted.unwrap(), data);
    }

    #[test]
    fn everyone_recipient_roundtrip_works() {
        let data = TestData { bla: [5, 6, 7, 8] };
        let encrypted = RecipientID::everyone().encrypt_from_anon(&data);
        assert_eq!(encrypted.recipient, RecipientID::everyone());
        let decrypted = RecipientSecret::everyone().decrypt(&encrypted);
        assert_eq!(decrypted.unwrap(), data);
    }

    #[test]
    fn each_encryption_uses_a_fresh_ephemeral_key() {
        // Same plaintext and recipient, but the ephemeral sender (and therefore the
        // nonce and ciphertext) must differ between calls.
        let data = TestData { bla: [1, 2, 3, 4] };
        let recipient = RecipientSecret::new_random().pub_id();
        let first = recipient.encrypt_from_anon(&data);
        let second = recipient.encrypt_from_anon(&data);
        assert_ne!(first, second);
    }

    #[test]
    fn can_not_decrypt_with_wrong_secret_key() {
        let data = TestData { bla: [1, 2, 3, 4] };
        let recipient = RecipientSecret::new_random();
        let encrypted = recipient.pub_id().encrypt_from_anon(&data);
        let fake_recipient = RecipientSecret::new_random();
        let decrypted = fake_recipient.decrypt(&encrypted);
        assert!(matches!(
            decrypted,
            Err(AsymmDecryptionError::DecryptionError)
        ));
    }

    #[test]
    fn can_not_decrypt_invalid_ciphertext() {
        let data = TestData { bla: [1, 2, 3, 4] };
        let recipient = RecipientSecret::new_random();
        let mut encrypted = recipient.pub_id().encrypt_from_anon(&data);
        encrypted.encrypted = vec![0, 0, 0, 1, 6];
        let decrypted = recipient.decrypt(&encrypted);
        assert!(matches!(
            decrypted,
            Err(AsymmDecryptionError::DecryptionError)
        ));
    }
}