use maidsafe_utilities::serialisation::{deserialise, serialise};
use safe_crypto::{PublicEncryptKey, SecretEncryptKey, SharedSecretKey};
use serde::de::DeserializeOwned;
use serde::Serialize;
/// Number of bytes a `u32` occupies after plain serialisation (the `Null` contexts).
const SERIALIZED_U32_LEN: usize = 4;
/// Number of bytes a `u32` occupies after authenticated or anonymous encryption.
///
/// The tests check that this length is the same for every `u32` value under both schemes.
const ENCRYPTED_U32_LEN: usize = 52;
/// Selects how `EncryptContext::encrypt` turns a message into bytes.
#[derive(Clone, Debug)]
pub enum EncryptContext {
    /// Plain serialisation, no encryption.
    Null,
    /// Encryption with a shared secret key; the peer decrypts with
    /// `DecryptContext::Authenticated` holding the matching shared key.
    Authenticated { shared_key: SharedSecretKey },
    /// Anonymous encryption to `their_pk`; only the holder of the matching
    /// secret key can decrypt (see `DecryptContext::AnonymousDecrypt`).
    AnonymousEncrypt { their_pk: PublicEncryptKey },
}
impl Default for EncryptContext {
fn default() -> Self {
Self::null()
}
}
impl EncryptContext {
    /// Builds a context that only serialises messages.
    pub fn null() -> Self {
        EncryptContext::Null
    }

    /// Builds a context that encrypts with the given shared secret key.
    pub fn authenticated(shared_key: SharedSecretKey) -> Self {
        EncryptContext::Authenticated { shared_key }
    }

    /// Builds a context that anonymously encrypts messages to `their_pk`.
    pub fn anonymous_encrypt(their_pk: PublicEncryptKey) -> Self {
        EncryptContext::AnonymousEncrypt { their_pk }
    }

    /// Converts `msg` into bytes according to this context.
    ///
    /// * `Null` serialises the message with `maidsafe_utilities::serialisation::serialise`.
    /// * `Authenticated` delegates to `SharedSecretKey::encrypt`.
    /// * `AnonymousEncrypt` delegates to `PublicEncryptKey::anonymously_encrypt`.
    ///
    /// # Errors
    ///
    /// Any serialisation or encryption failure is propagated with `?`, converted into
    /// the crate's error type.
    pub fn encrypt<T: Serialize>(&self, msg: &T) -> ::Res<Vec<u8>> {
        let bytes: Vec<u8> = match *self {
            EncryptContext::Null => serialise(msg)?,
            EncryptContext::Authenticated { ref shared_key } => shared_key.encrypt(msg)?,
            EncryptContext::AnonymousEncrypt { ref their_pk } => {
                their_pk.anonymously_encrypt(msg)?
            }
        };
        Ok(bytes)
    }

    /// Number of bytes a `u32` occupies once passed through `encrypt` with this context.
    ///
    /// NOTE(review): presumably used to size a length prefix for framed messages;
    /// confirm against callers.
    pub fn encrypted_size_len(&self) -> usize {
        match *self {
            EncryptContext::Null => SERIALIZED_U32_LEN,
            // Both encrypting variants produce the same fixed-size output for a u32.
            EncryptContext::Authenticated { .. } | EncryptContext::AnonymousEncrypt { .. } => {
                ENCRYPTED_U32_LEN
            }
        }
    }
}
/// Selects how `DecryptContext::decrypt` turns received bytes back into a message.
///
/// NOTE(review): `Debug` is derived over a `SecretEncryptKey`; confirm that its `Debug`
/// impl does not print key material before logging this type.
#[derive(Clone, Debug)]
pub enum DecryptContext {
    /// Plain deserialisation, no decryption.
    Null,
    /// Decryption with a shared secret key (counterpart of `EncryptContext::Authenticated`).
    Authenticated { shared_key: SharedSecretKey },
    /// Decryption of messages anonymously encrypted to our key pair
    /// (counterpart of `EncryptContext::AnonymousEncrypt`).
    AnonymousDecrypt {
        /// Our public key, passed to `anonymously_decrypt` alongside the secret key.
        our_pk: PublicEncryptKey,
        /// Our secret key, which performs the decryption.
        our_sk: SecretEncryptKey,
    },
}
impl Default for DecryptContext {
fn default() -> Self {
Self::null()
}
}
impl DecryptContext {
    /// Builds a context that only deserialises bytes.
    pub fn null() -> Self {
        DecryptContext::Null
    }

    /// Builds a context that decrypts with the given shared secret key.
    pub fn authenticated(shared_key: SharedSecretKey) -> Self {
        DecryptContext::Authenticated { shared_key }
    }

    /// Builds a context that decrypts messages anonymously encrypted to `our_pk`.
    pub fn anonymous_decrypt(our_pk: PublicEncryptKey, our_sk: SecretEncryptKey) -> Self {
        DecryptContext::AnonymousDecrypt { our_pk, our_sk }
    }

    /// Converts `msg` back into a value of type `T` according to this context.
    ///
    /// * `Null` deserialises with `maidsafe_utilities::serialisation::deserialise`.
    /// * `Authenticated` delegates to `SharedSecretKey::decrypt`.
    /// * `AnonymousDecrypt` delegates to `SecretEncryptKey::anonymously_decrypt`.
    ///
    /// NOTE(review): the `Serialize` bound on `T` may be stricter than this body needs;
    /// it is kept as-is for API compatibility.
    ///
    /// # Errors
    ///
    /// Any deserialisation or decryption failure is propagated with `?`, converted into
    /// the crate's error type.
    pub fn decrypt<T>(&self, msg: &[u8]) -> ::Res<T>
    where
        T: Serialize + DeserializeOwned,
    {
        let value: T = match *self {
            DecryptContext::Null => deserialise(msg)?,
            DecryptContext::Authenticated { ref shared_key } => shared_key.decrypt(msg)?,
            DecryptContext::AnonymousDecrypt {
                ref our_pk,
                ref our_sk,
            } => our_sk.anonymously_decrypt(msg, our_pk)?,
        };
        Ok(value)
    }

    /// Number of bytes a `u32` occupies when produced by the matching `EncryptContext`.
    ///
    /// NOTE(review): presumably used to know how many bytes to read for a length prefix;
    /// confirm against callers.
    pub fn encrypted_size_len(&self) -> usize {
        match *self {
            DecryptContext::Null => SERIALIZED_U32_LEN,
            // Both decrypting variants expect the same fixed-size encrypted u32.
            DecryptContext::Authenticated { .. } | DecryptContext::AnonymousDecrypt { .. } => {
                ENCRYPTED_U32_LEN
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use hamcrest2::prelude::*;
    use safe_crypto::gen_encrypt_keypair;
    use std::u32::MAX as MAX_U32;
    use DEFAULT_MAX_PAYLOAD_SIZE;

    /// Message used by every round-trip test in this module.
    const PLAINTEXT: &[u8; 7] = b"test123";

    /// Encrypts `PLAINTEXT` with `enc_ctx`, decrypts it with `dec_ctx` and checks that the
    /// original bytes come back unchanged.
    fn assert_round_trip(enc_ctx: &EncryptContext, dec_ctx: &DecryptContext) {
        let cipher_text = unwrap!(enc_ctx.encrypt(PLAINTEXT));
        let recovered: [u8; 7] = unwrap!(dec_ctx.decrypt(&cipher_text[..]));
        assert_eq!(&recovered, PLAINTEXT);
    }

    mod encrypt_context {
        use super::*;

        /// Checks that every sample `u32`, from 0 up to `u32::MAX`, encrypts to exactly
        /// `ENCRYPTED_U32_LEN` bytes under `enc_ctx`.
        fn assert_u32_encrypts_to_constant_len(enc_ctx: &EncryptContext) {
            let samples = [0u32, 25000, DEFAULT_MAX_PAYLOAD_SIZE as u32, MAX_U32];
            for sample in samples.iter() {
                let encrypted = unwrap!(enc_ctx.encrypt(sample));
                assert_that!(&encrypted, len(ENCRYPTED_U32_LEN));
            }
        }

        #[test]
        fn encrypt_always_returns_constant_length_byte_array_for_4_byte_input_with_anonymous_encryption(
        ) {
            let (their_pk, _) = gen_encrypt_keypair();
            assert_u32_encrypts_to_constant_len(&EncryptContext::anonymous_encrypt(their_pk));
        }

        #[test]
        fn encrypt_always_returns_constant_length_byte_array_for_4_byte_input_with_authenticated_encryption(
        ) {
            let (_, our_sk) = gen_encrypt_keypair();
            let (their_pk, _) = gen_encrypt_keypair();
            let shared_key = our_sk.shared_secret(&their_pk);
            assert_u32_encrypts_to_constant_len(&EncryptContext::authenticated(shared_key));
        }
    }

    #[test]
    fn null_encryption_serializes_and_deserializes_data() {
        assert_round_trip(&EncryptContext::null(), &DecryptContext::null());
    }

    #[test]
    fn authenticated_encryption_encrypts_and_decrypts_data() {
        let (our_pk, our_sk) = gen_encrypt_keypair();
        let (their_pk, their_sk) = gen_encrypt_keypair();
        // Each side derives the same shared key from its own secret and the other's public key.
        let enc_ctx = EncryptContext::authenticated(our_sk.shared_secret(&their_pk));
        let dec_ctx = DecryptContext::authenticated(their_sk.shared_secret(&our_pk));
        assert_round_trip(&enc_ctx, &dec_ctx);
    }

    #[test]
    fn anonymous_encryption() {
        let (pk, sk) = gen_encrypt_keypair();
        let enc_ctx = EncryptContext::anonymous_encrypt(pk);
        let dec_ctx = DecryptContext::anonymous_decrypt(pk, sk);
        assert_round_trip(&enc_ctx, &dec_ctx);
    }
}