use crate::{
errors::KeyManagementError, DecryptionStrategy, EncryptedMessage, KeyIdentifier, PublicKey,
SentMessage, SharedEncryption,
};
use secrecy::ExposeSecret;
use serde::{de::DeserializeOwned, Serialize};
use std::{
convert::TryFrom,
error::Error as StdError,
fmt::{self, Debug},
hash::{Hash, Hasher},
marker::PhantomData,
};
use thiserror::Error;
pub fn encrypt<T, E>(
encryption: &E,
sender_key_id: &KeyIdentifier,
receiver_pubkey: &[u8],
value: &T,
ephemeral_key_input: &[u8],
) -> Result<EncryptedMessage, Error>
where
E: SharedEncryption + ?Sized,
T: Serialize + ?Sized,
{
let message = serde_json::to_vec(value).map_err(Error::serializing)?;
Ok(encryption.shared_encrypt(
sender_key_id,
receiver_pubkey,
&message,
ephemeral_key_input,
)?)
}
pub fn decrypt<T, E>(
encryption: &E,
key_id: &KeyIdentifier,
decryption_strategy: DecryptionStrategy,
) -> Result<T, Error>
where
E: SharedEncryption + ?Sized,
T: DeserializeOwned,
{
let message = encryption.shared_decrypt(key_id, decryption_strategy)?;
serde_json::from_slice(message.expose_secret()).map_err(Error::serializing)
}
/// An [`EncryptedMessage`] tagged at the type level with `T`.
///
/// `T` appears only inside a [`PhantomData`] marker; no value of type `T` is stored.
// NOTE(review): the `Deserialize` derive macro is not named in the `use serde::{...}` line
// visible in this file — presumably it is brought into scope crate-wide; confirm.
#[derive(Deserialize, Serialize)]
pub struct Encrypted<T> {
/// The encrypted message wrapped by this value.
payload: EncryptedMessage,
/// Zero-sized marker carrying the type parameter `T`.
type_: PhantomData<T>,
}
impl<T> Encrypted<T> {
    /// Borrows the wrapped [`EncryptedMessage`].
    pub fn payload(&self) -> &EncryptedMessage {
        let Self { payload, .. } = self;
        payload
    }

    /// Parses `bytes` with `EncryptedMessage::from_bytes` and wraps the result.
    ///
    /// "Unchecked" because nothing here decrypts or inspects the plaintext: the choice of
    /// `T` is taken on trust from the caller.
    ///
    /// # Errors
    ///
    /// Propagates the failure of `EncryptedMessage::from_bytes`, converted into [`Error`].
    pub fn from_bytes_unchecked(bytes: &[u8]) -> Result<Self, Error> {
        let payload = EncryptedMessage::from_bytes(bytes)?;
        let type_ = PhantomData;
        Ok(Self { payload, type_ })
    }

    /// Returns the bytes produced by `EncryptedMessage::to_bytes` for the wrapped message.
    pub fn to_bytes(&self) -> Vec<u8> {
        let Self { payload, .. } = self;
        payload.to_bytes()
    }
}
impl<T: Serialize> Encrypted<T> {
    /// Encrypts `value` and tags the resulting message with its type `T`.
    ///
    /// All arguments are forwarded unchanged to the module-level [`encrypt`] function.
    ///
    /// # Errors
    ///
    /// Returns whatever error the module-level [`encrypt`] function returns.
    pub fn encrypt<E>(
        encryption: &E,
        sender_key_id: &KeyIdentifier,
        receiver_pubkey: &[u8],
        value: &T,
        ephemeral_key_input: &[u8],
    ) -> Result<Self, Error>
    where
        E: SharedEncryption + ?Sized,
    {
        let payload = encrypt(
            encryption,
            sender_key_id,
            receiver_pubkey,
            value,
            ephemeral_key_input,
        )?;
        Ok(Self {
            payload,
            type_: PhantomData,
        })
    }
}
impl<T: DeserializeOwned> Encrypted<T> {
pub fn decrypt_as_receiver<E>(
&self,
encryption: &E,
receiver_key_id: &KeyIdentifier,
) -> Result<T, Error>
where
E: SharedEncryption + ?Sized,
{
decrypt(
encryption,
receiver_key_id,
DecryptionStrategy::Receiver(self.payload.clone()),
)
}
pub fn decrypt_as_sender<E>(
&self,
encryption: &E,
sender_key_id: &KeyIdentifier,
receiver_pubkey: &[u8],
ephemeral_key_input: &[u8],
) -> Result<T, Error>
where
E: SharedEncryption + ?Sized,
{
let sent_msg = SentMessage {
ephemeral_key_input: ephemeral_key_input.to_vec(),
receiver_pubkey: PublicKey::try_from(receiver_pubkey).map_err(|e| {
Error::Encryption(KeyManagementError::ByteConversionError {
message: e.to_string(),
})
})?,
ciphertext: self.payload.ciphertext.clone(),
};
decrypt(
encryption,
sender_key_id,
DecryptionStrategy::Sender(sent_msg),
)
}
}
impl<T> Clone for Encrypted<T> {
fn clone(&self) -> Self {
Encrypted {
payload: self.payload.clone(),
type_: self.type_,
}
}
}
// The debug output names only the type parameter; the payload is never printed.
impl<T> Debug for Encrypted<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_name = std::any::type_name::<T>();
        write!(f, "<Encrypted value of type {}>", type_name)
    }
}
// Two values are equal exactly when their payloads are equal; no bound on `T` is needed.
impl<T> PartialEq for Encrypted<T> {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.payload, &other.payload)
    }
}
// Marker impl with no methods; it places no bound on `T`.
impl<T> Eq for Encrypted<T> {}
// Feeds both fields to the hasher in declaration order. `PhantomData`'s `Hash` impl
// writes nothing, so the hash is determined by the payload alone.
impl<T> Hash for Encrypted<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self { payload, type_ } = self;
        payload.hash(state);
        type_.hash(state);
    }
}
/// Failure modes of the typed encryption helpers in this module.
#[derive(Debug, Error)]
pub enum Error {
/// A [`KeyManagementError`]; `#[from]` lets `?` convert one into this variant.
#[error("Encryption error: {0}")]
Encryption(#[from] KeyManagementError),
/// A boxed JSON (de)serialization failure, constructed through `Error::serializing`.
#[error("Serialization error: {0}")]
Serialization(#[source] Box<dyn StdError + Send + Sync>),
}
impl Error {
fn serializing<E>(e: E) -> Self
where
E: StdError + Send + Sync + 'static,
{
Error::Serialization(Box::new(e))
}
}
#[cfg(test)]
mod tests {
    use crate::{
        errors::KeyManagementError, in_memory_key_manager::InMemoryKeyManager, typed::Encrypted,
        KeyManagement, KeyType, NewKeyId,
    };
    use serde::{de::DeserializeOwned, Serialize};

    /// Ephemeral key input shared by every test in this module.
    const EPHEMERAL_KEY_INPUT: &[u8] = &[0, 1, 2];

    /// An in-memory key manager holding two generated key pairs: `alice_key` is used as
    /// the sender and `bob_key` as the receiver.
    struct TestContext {
        key_manager: InMemoryKeyManager,
        alice_key: NewKeyId,
        bob_key: NewKeyId,
    }

    impl TestContext {
        /// Creates a key manager and generates both key pairs in it.
        fn new() -> Self {
            let mut key_manager = InMemoryKeyManager::new(Default::default());
            let alice_key = key_manager
                .generate_keypair(KeyType::SharedEncryptionX25519)
                .unwrap();
            let bob_key = key_manager
                .generate_keypair(KeyType::SharedEncryptionX25519)
                .unwrap();
            Self {
                key_manager,
                alice_key,
                bob_key,
            }
        }

        /// Encrypts `value` from Alice to Bob.
        fn encrypt<T>(&self, value: &T) -> Result<Encrypted<T>, super::Error>
        where
            T: Serialize,
        {
            let sender_id = &self.alice_key.id;
            let receiver_pubkey = &self.bob_key.pubkey;
            Encrypted::encrypt(
                &self.key_manager,
                sender_id,
                receiver_pubkey,
                value,
                EPHEMERAL_KEY_INPUT,
            )
        }

        /// Decrypts `encrypted` as Bob, the receiver.
        fn decrypt<T>(&self, encrypted: &Encrypted<T>) -> Result<T, super::Error>
        where
            T: DeserializeOwned,
        {
            let receiver_id = &self.bob_key.id;
            encrypted.decrypt_as_receiver(&self.key_manager, receiver_id)
        }
    }

    /// Encrypting and then decrypting as the receiver yields the original value.
    #[test]
    fn encryption_roundtrips() {
        let context = TestContext::new();
        let plaintext = String::from("shhh");
        let sealed = context.encrypt(&plaintext).unwrap();
        let opened = context.decrypt(&sealed).unwrap();
        assert_eq!(opened, plaintext);
    }

    /// The sender can recover the value from its own key and the original inputs.
    #[test]
    fn sender_can_decrypt() {
        let context = TestContext::new();
        let plaintext = String::from("shhh");
        let sealed = context.encrypt(&plaintext).unwrap();
        let opened = sealed
            .decrypt_as_sender(
                &context.key_manager,
                &context.alice_key.id,
                &context.bob_key.pubkey,
                EPHEMERAL_KEY_INPUT,
            )
            .unwrap();
        assert_eq!(opened, plaintext);
    }

    /// Re-tagging an encrypted integer as `String` fails at deserialization time.
    #[test]
    fn decrypting_wrong_type_fails() {
        let context = TestContext::new();
        let sealed_number = context.encrypt(&33).unwrap();
        let retagged =
            Encrypted::<String>::from_bytes_unchecked(&sealed_number.to_bytes()).unwrap();
        let outcome = context.decrypt(&retagged);
        assert!(matches!(outcome, Err(super::Error::Serialization(_))));
    }

    /// `to_bytes` followed by `from_bytes_unchecked` preserves the encrypted value.
    #[test]
    fn byte_serialization_roundtrips() {
        let context = TestContext::new();
        let plaintext = String::from("shhh");
        let raw = context.encrypt(&plaintext).unwrap().to_bytes();
        let restored = Encrypted::<String>::from_bytes_unchecked(&raw).unwrap();
        let opened = context.decrypt(&restored).unwrap();
        assert_eq!(opened, plaintext);
    }

    /// Bytes that are not a valid encrypted message are rejected with a conversion error.
    #[test]
    fn deserializing_garbage_bytes_fails() {
        let outcome = Encrypted::<String>::from_bytes_unchecked(&[33]);
        assert!(matches!(
            outcome,
            Err(super::Error::Encryption(
                KeyManagementError::ByteConversionError { .. }
            ))
        ));
    }
}