tpfs_krypt 7.1.8

An interface for accessing secrets
Documentation
//! Tools to encrypt and decrypt values while preserving their type at compile time
//!
//! # Example
//! ```rust
//! use rand::{distributions::Standard, Rng};
//! use tpfs_krypt::{Encrypted, KeyManagement, KeyType, NewKeyId, SharedEncryption};
//!
//! fn alice_and_bob_talk<E: SharedEncryption + ?Sized>(
//!     encryption: &E,
//!     alice: &NewKeyId,
//!     bob: &NewKeyId,
//! ) {
//!     let message = String::from("Hi Bob!");
//!     let ephemeral_key_input = rand::thread_rng()
//!         .sample_iter(Standard)
//!         .take(32)
//!         .collect::<Vec<u8>>();
//!     let encrypted = Encrypted::encrypt(
//!         encryption,
//!         &alice.id,
//!         &bob.pubkey,
//!         &message,
//!         &ephemeral_key_input,
//!     ).unwrap();
//!     let decrypted = encrypted.decrypt_as_receiver(encryption, &bob.id).unwrap();
//!     assert_eq!(message, decrypted);
//! }
//!
//! let mut key_manager = tpfs_krypt::from_config(tpfs_krypt::config::KryptConfig {
//!     key_manager_config: tpfs_krypt::config::KeyManagerConfig::InMemoryKeyManager(
//!         Default::default()
//!     ),
//! }).unwrap();
//! let alice = key_manager.generate_keypair(KeyType::SharedEncryptionX25519).unwrap();
//! let bob = key_manager.generate_keypair(KeyType::SharedEncryptionX25519).unwrap();
//! alice_and_bob_talk(&*key_manager, &alice, &bob);
//! ```

use crate::{
    errors::KeyManagementError, DecryptionStrategy, EncryptedMessage, KeyIdentifier, PublicKey,
    SentMessage, SharedEncryption,
};
use secrecy::ExposeSecret;
use serde::{de::DeserializeOwned, Serialize};
use std::{
    convert::TryFrom,
    error::Error as StdError,
    fmt::{self, Debug},
    hash::{Hash, Hasher},
    marker::PhantomData,
};
use thiserror::Error;

/// Encrypts a value
///
/// See `SharedEncryption::shared_encrypt`
pub fn encrypt<T, E>(
    encryption: &E,
    sender_key_id: &KeyIdentifier,
    receiver_pubkey: &[u8],
    value: &T,
    ephemeral_key_input: &[u8],
) -> Result<EncryptedMessage, Error>
where
    E: SharedEncryption + ?Sized,
    T: Serialize + ?Sized,
{
    let message = serde_json::to_vec(value).map_err(Error::serializing)?;
    Ok(encryption.shared_encrypt(
        sender_key_id,
        receiver_pubkey,
        &message,
        ephemeral_key_input,
    )?)
}

/// Decrypts a value
///
/// `key_id` is the sender's key identifier if the decryption strategy is `Sender`, otherwise it is
/// the receiver's key identifier.
///
/// See `SharedEncryption::shared_decrypt`
pub fn decrypt<T, E>(
    encryption: &E,
    key_id: &KeyIdentifier,
    decryption_strategy: DecryptionStrategy,
) -> Result<T, Error>
where
    E: SharedEncryption + ?Sized,
    T: DeserializeOwned,
{
    let message = encryption.shared_decrypt(key_id, decryption_strategy)?;
    serde_json::from_slice(message.expose_secret()).map_err(Error::serializing)
}

/// Encrypted value of type `T`
///
/// Holds an [`EncryptedMessage`] together with a zero-sized `PhantomData<T>` marker. The
/// plaintext type is therefore tracked at compile time, and the `decrypt_as_*` methods yield a
/// `T` without the caller naming the type again.
#[derive(Deserialize, Serialize)]
pub struct Encrypted<T> {
    // The encrypted payload. Field names and order form the serde representation, so renaming or
    // reordering fields changes the serialized format.
    payload: EncryptedMessage,
    // Zero-sized marker for the plaintext type; carries no data.
    type_: PhantomData<T>,
}

impl<T> Encrypted<T> {
    /// Returns the encrypted payload
    pub fn payload(&self) -> &EncryptedMessage {
        &self.payload
    }

    /// Creates an `Encrypted` value from bytes
    ///
    /// Note that this method cannot check that the bytes really are the result of encrypting a
    /// value of type `T`, so callers should ensure that. This does not cause unsoundness though;
    /// an attempt to decrypt an invalid payload would return an error.
    ///
    /// This is useful to convert from representations that use bytes and do not rely on `serde`.
    pub fn from_bytes_unchecked(bytes: &[u8]) -> Result<Self, Error> {
        let payload = EncryptedMessage::from_bytes(bytes)?;
        Ok(Encrypted {
            payload,
            type_: PhantomData,
        })
    }

    /// Serializes this encrypted value to bytes
    ///
    /// This is the same as calling `to_bytes` on the inner `EncryptedMessage`. The primary use
    /// case is to support formats that need bytes and do not rely on `serde`.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.payload.to_bytes()
    }
}

impl<T: Serialize> Encrypted<T> {
    /// Encrypts a value
    ///
    /// Serializes and encrypts `value` via the module-level `encrypt` function and tags the
    /// resulting payload with the type `T`.
    ///
    /// See `SharedEncryption::shared_encrypt`
    pub fn encrypt<E>(
        encryption: &E,
        sender_key_id: &KeyIdentifier,
        receiver_pubkey: &[u8],
        value: &T,
        ephemeral_key_input: &[u8],
    ) -> Result<Self, Error>
    where
        E: SharedEncryption + ?Sized,
    {
        // A bare `encrypt` here resolves to the free function, not to this method.
        let payload = encrypt(
            encryption,
            sender_key_id,
            receiver_pubkey,
            value,
            ephemeral_key_input,
        )?;
        Ok(Self {
            payload,
            type_: PhantomData,
        })
    }
}

impl<T: DeserializeOwned> Encrypted<T> {
    /// Decrypts a received value
    ///
    /// `receiver_key_id` identifies the receiver's key; the stored payload is handed to the
    /// `Receiver` decryption strategy.
    pub fn decrypt_as_receiver<E>(
        &self,
        encryption: &E,
        receiver_key_id: &KeyIdentifier,
    ) -> Result<T, Error>
    where
        E: SharedEncryption + ?Sized,
    {
        let strategy = DecryptionStrategy::Receiver(self.payload.clone());
        decrypt(encryption, receiver_key_id, strategy)
    }

    /// Decrypts a sent value
    ///
    /// The parameters match the ones used to encrypt. A `receiver_pubkey` that cannot be
    /// converted into a `PublicKey` is reported as a `ByteConversionError`.
    pub fn decrypt_as_sender<E>(
        &self,
        encryption: &E,
        sender_key_id: &KeyIdentifier,
        receiver_pubkey: &[u8],
        ephemeral_key_input: &[u8],
    ) -> Result<T, Error>
    where
        E: SharedEncryption + ?Sized,
    {
        let receiver_pubkey = PublicKey::try_from(receiver_pubkey).map_err(|err| {
            Error::Encryption(KeyManagementError::ByteConversionError {
                message: err.to_string(),
            })
        })?;
        let strategy = DecryptionStrategy::Sender(SentMessage {
            ephemeral_key_input: ephemeral_key_input.to_vec(),
            receiver_pubkey,
            ciphertext: self.payload.ciphertext.clone(),
        });
        decrypt(encryption, sender_key_id, strategy)
    }
}

impl<T> Clone for Encrypted<T> {
    fn clone(&self) -> Self {
        Encrypted {
            payload: self.payload.clone(),
            type_: self.type_,
        }
    }
}

// Prints only the name of the plaintext type, never the payload itself.
impl<T> Debug for Encrypted<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_name = std::any::type_name::<T>();
        f.write_str("<Encrypted value of type ")?;
        f.write_str(type_name)?;
        f.write_str(">")
    }
}

impl<T> PartialEq for Encrypted<T> {
    fn eq(&self, other: &Self) -> bool {
        self.payload == other.payload
    }
}

impl<T> Eq for Encrypted<T> {}

impl<T> Hash for Encrypted<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.payload.hash(state);
        self.type_.hash(state);
    }
}

/// Error type returned when encrypting or decrypting a value
#[derive(Debug, Error)]
pub enum Error {
    /// Error when encrypting or decrypting bytes
    ///
    /// Wraps a [`KeyManagementError`]; the `#[from]` attribute generates the `From` impl that
    /// lets `?` convert one into this variant.
    // NOTE(review): the Display text embeds the wrapped error's message, and the same error is
    // also exposed via `source()`, so reporters that walk the source chain print it twice.
    #[error("Encryption error: {0}")]
    Encryption(#[from] KeyManagementError),
    /// Error when serializing or deserializing
    ///
    /// Produced while converting the plaintext value to or from its JSON byte form; the boxed
    /// error is exposed through `std::error::Error::source`.
    #[error("Serialization error: {0}")]
    Serialization(#[source] Box<dyn StdError + Send + Sync>),
}

impl Error {
    fn serializing<E>(e: E) -> Self
    where
        E: StdError + Send + Sync + 'static,
    {
        Error::Serialization(Box::new(e))
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        errors::KeyManagementError, in_memory_key_manager::InMemoryKeyManager, typed::Encrypted,
        KeyManagement, KeyType, NewKeyId,
    };
    use serde::{de::DeserializeOwned, Serialize};

    const EPHEMERAL_KEY_INPUT: &[u8] = &[0, 1, 2];

    /// In-memory key manager with two X25519 keypairs: Alice sends, Bob receives.
    struct TestContext {
        key_manager: InMemoryKeyManager,
        alice_key: NewKeyId,
        bob_key: NewKeyId,
    }

    impl TestContext {
        fn new() -> Self {
            let mut key_manager = InMemoryKeyManager::new(Default::default());
            let alice_key = key_manager
                .generate_keypair(KeyType::SharedEncryptionX25519)
                .unwrap();
            let bob_key = key_manager
                .generate_keypair(KeyType::SharedEncryptionX25519)
                .unwrap();
            TestContext {
                key_manager,
                alice_key,
                bob_key,
            }
        }

        /// Encrypts `value` from Alice to Bob.
        fn encrypt<T>(&self, value: &T) -> Result<Encrypted<T>, super::Error>
        where
            T: Serialize,
        {
            Encrypted::encrypt(
                &self.key_manager,
                &self.alice_key.id,
                &self.bob_key.pubkey,
                value,
                EPHEMERAL_KEY_INPUT,
            )
        }

        /// Decrypts `encrypted` as Bob, the receiver.
        fn decrypt<T>(&self, encrypted: &Encrypted<T>) -> Result<T, super::Error>
        where
            T: DeserializeOwned,
        {
            encrypted.decrypt_as_receiver(&self.key_manager, &self.bob_key.id)
        }
    }

    #[test]
    fn encryption_roundtrips() {
        let secret = String::from("shhh");
        let ctx = TestContext::new();
        let encrypted = ctx.encrypt(&secret).unwrap();
        let decrypted = ctx.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, secret);
    }

    #[test]
    fn sender_can_decrypt() {
        let secret = String::from("shhh");
        let ctx = TestContext::new();
        let encrypted = ctx.encrypt(&secret).unwrap();
        let decrypted = encrypted
            .decrypt_as_sender(
                &ctx.key_manager,
                &ctx.alice_key.id,
                &ctx.bob_key.pubkey,
                EPHEMERAL_KEY_INPUT,
            )
            .unwrap();
        assert_eq!(decrypted, secret);
    }

    #[test]
    fn decrypting_wrong_type_fails() {
        let ctx = TestContext::new();
        let encrypted = ctx.encrypt(&33).unwrap();
        let encrypted = Encrypted::<String>::from_bytes_unchecked(&encrypted.to_bytes()).unwrap();
        let decrypted = ctx.decrypt(&encrypted);
        assert!(matches!(decrypted, Err(super::Error::Serialization(_))));
    }

    #[test]
    fn byte_serialization_roundtrips() {
        let secret = String::from("shhh");
        let ctx = TestContext::new();
        let encrypted = ctx.encrypt(&secret).unwrap();
        let bytes = encrypted.to_bytes();
        let encrypted = Encrypted::<String>::from_bytes_unchecked(&bytes).unwrap();
        let decrypted = ctx.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, secret);
    }

    #[test]
    fn deserializing_garbage_bytes_fails() {
        assert!(matches!(
            Encrypted::<String>::from_bytes_unchecked(&[33]),
            Err(super::Error::Encryption(
                KeyManagementError::ByteConversionError { .. }
            ))
        ));
    }

    // The hand-written `Clone`, `PartialEq` and `Hash` impls must agree with each other:
    // a clone compares equal and hashes identically.
    #[test]
    fn clone_is_equal_and_hashes_identically() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        fn hash_of<V: Hash>(value: &V) -> u64 {
            let mut hasher = DefaultHasher::new();
            value.hash(&mut hasher);
            hasher.finish()
        }

        let ctx = TestContext::new();
        let encrypted = ctx.encrypt(&String::from("shhh")).unwrap();
        let cloned = encrypted.clone();
        assert_eq!(cloned, encrypted);
        assert_eq!(hash_of(&cloned), hash_of(&encrypted));
    }

    // `Debug` must describe only the plaintext type and never reveal the payload.
    #[test]
    fn debug_shows_type_only() {
        let ctx = TestContext::new();
        let encrypted = ctx.encrypt(&String::from("shhh")).unwrap();
        let rendered = format!("{:?}", encrypted);
        assert!(rendered.starts_with("<Encrypted value of type "));
        assert!(rendered.ends_with('>'));
        assert!(rendered.contains("String"));
        assert!(!rendered.contains("shhh"));
    }
}