use core::mem::size_of;
use ssb_crypto::{Keypair, PublicKey};
use zerocopy::{AsBytes, FromBytes, LayoutVerified};
use ssb_crypto::ephemeral::{
derive_shared_secret_pk, derive_shared_secret_sk, generate_ephemeral_keypair, EphPublicKey,
};
use ssb_crypto::secretbox::{Hmac, Key, Nonce};
/// Maximum number of recipients a single message can be encrypted for.
///
/// `encrypt_into` panics when given more recipients than this, and
/// `decrypt_key` examines at most this many boxed keys when searching for one
/// it can open.
const MAX_RECIPIENTS: usize = 8;
/// Delegates to `ssb_crypto::sodium::init`.
///
/// Only compiled when the `sodium` cargo feature is enabled; presumably it
/// prepares the sodium-based crypto backend before use (see `ssb_crypto` docs).
#[cfg(feature = "sodium")]
pub fn init() {
    use ssb_crypto::sodium;
    sodium::init();
}
/// The symmetric key protecting a message body, together with the number of
/// recipients the message was encrypted for.
///
/// `#[repr(C, packed)]` with this field order fixes the 33-byte encoding that
/// is sealed once per recipient: one count byte followed by the 32-byte key.
#[derive(AsBytes, FromBytes)]
#[repr(C, packed)]
pub struct MsgKey {
    recp_count: u8,
    key: Key,
}
impl MsgKey {
    /// Builds a key with a zero recipient count and an all-zero key, suitable
    /// as a scratch buffer to decrypt an encoded `MsgKey` into.
    fn zeroed() -> MsgKey {
        let blank = Key([0; 32]);
        MsgKey {
            recp_count: 0,
            key: blank,
        }
    }
    /// Copies the packed 33-byte encoding (`recp_count` then `key`) into an
    /// owned array.
    pub fn as_array(&self) -> [u8; 33] {
        let encoded = self.as_bytes();
        let mut array = [0; 33];
        array[..].copy_from_slice(encoded);
        array
    }
}
/// One recipient's entry in the ciphertext header.
///
/// Holds the tag returned by sealing the encoded `MsgKey` (`hmac`), followed by
/// the sealed 33-byte `MsgKey` encoding itself (`msg_key`). Each recipient's
/// entry is sealed under a key derived from the ephemeral secret key and that
/// recipient's public key.
///
/// Field order plus `#[repr(C, packed)]` define the serialized layout
/// (`hmac || msg_key`, no padding), so the fields must not be reordered.
#[derive(AsBytes, FromBytes)]
#[repr(C, packed)]
struct BoxedKey {
    hmac: Hmac,
    msg_key: [u8; 33],
}
/// Returns the exact number of bytes `encrypt` produces for `text` sent to
/// `recps`.
///
/// The ciphertext is: nonce, ephemeral public key, one `BoxedKey` per
/// recipient, the body HMAC, and the encrypted body (same length as `text`).
pub fn encrypted_size(text: &[u8], recps: &[PublicKey]) -> usize {
    let header = size_of::<Nonce>() + size_of::<EphPublicKey>();
    let boxed_keys = recps.len() * size_of::<BoxedKey>();
    let body = size_of::<Hmac>() + text.len();
    header + boxed_keys + body
}
/// Copies `prefix` into the start of `buf` and returns the remainder of `buf`
/// after it.
///
/// # Panics
///
/// Panics if `prefix` is longer than `buf`.
fn set_prefix<'a>(buf: &'a mut [u8], prefix: &[u8]) -> &'a mut [u8] {
    let split = prefix.len();
    buf[..split].copy_from_slice(prefix);
    &mut buf[split..]
}
/// Encrypts `plaintext` for `recipients`, returning a freshly allocated
/// ciphertext of exactly `encrypted_size(plaintext, recipients)` bytes.
///
/// # Panics
///
/// Panics under the same conditions as `encrypt_into` (e.g. zero recipients or
/// more than `MAX_RECIPIENTS`).
pub fn encrypt(plaintext: &[u8], recipients: &[PublicKey]) -> Vec<u8> {
    let total = encrypted_size(plaintext, recipients);
    let mut cyphertext = vec![0; total];
    encrypt_into(plaintext, recipients, &mut cyphertext);
    cyphertext
}
/// Encrypts `plaintext` for `recipients`, writing the ciphertext into the
/// front of `out`.
///
/// Layout written (its length is `encrypted_size(plaintext, recipients)`):
/// `nonce || ephemeral public key || one BoxedKey per recipient || body hmac || body`.
///
/// Only the first `encrypted_size(plaintext, recipients)` bytes of `out` are
/// written; any bytes beyond that are left untouched.
///
/// # Panics
///
/// Panics if `recipients` is empty or longer than `MAX_RECIPIENTS`, if `out`
/// is shorter than `encrypted_size(plaintext, recipients)`, or if a shared
/// secret cannot be derived for one of the recipient public keys.
pub fn encrypt_into(plaintext: &[u8], recipients: &[PublicKey], out: &mut [u8]) {
    if recipients.is_empty() || recipients.len() > MAX_RECIPIENTS {
        panic!(
            "Number of recipients must be greater than 0 and at most {}",
            MAX_RECIPIENTS
        );
    }
    let size = encrypted_size(plaintext, recipients);
    assert!(out.len() >= size);
    // Work on exactly `size` bytes: the body is filled with `copy_from_slice`,
    // which requires equal lengths, so an oversized `out` would otherwise panic.
    let out = &mut out[..size];

    let nonce = Nonce::generate();
    let (eph_pk, eph_sk) = generate_ephemeral_keypair();
    let mkey = MsgKey {
        // Cannot truncate: the length was checked against MAX_RECIPIENTS above.
        recp_count: recipients.len() as u8,
        key: Key::generate(),
    };

    let rest = set_prefix(out, nonce.as_bytes());
    let rest = set_prefix(rest, eph_pk.as_bytes());
    let (keys, rest) = rest.split_at_mut(recipients.len() * size_of::<BoxedKey>());

    // Seal a copy of the encoded message key for each recipient, under a key
    // derived from the ephemeral secret and that recipient's public key.
    let slots = keys.chunks_exact_mut(size_of::<BoxedKey>());
    for (pk, slot) in recipients.iter().zip(slots) {
        let kkey = Key(
            derive_shared_secret_pk(&eph_sk, pk)
                .expect("failed to derive shared secret for recipient public key")
                .0,
        );
        let mut msg_key = mkey.as_array();
        let hmac = kkey.seal(&mut msg_key, &nonce);
        slot.copy_from_slice(BoxedKey { hmac, msg_key }.as_bytes());
    }

    // Encrypt the body in place, then store its tag just before it.
    let (hmac_buf, text) = rest.split_at_mut(Hmac::SIZE);
    text.copy_from_slice(plaintext);
    let hmac = mkey.key.seal(text, &nonce);
    hmac_buf.copy_from_slice(hmac.as_bytes());
}
/// Size in bytes of one boxed key entry in the ciphertext header, as read by
/// `decrypt_key` and `decrypt_body`: a 32-byte key plus a 1-byte recipient
/// count (the encoded `MsgKey`), plus what is presumably a 16-byte HMAC tag.
// NOTE(review): this must equal `size_of::<BoxedKey>()` used on the encrypt
// side; consider deriving it from that to avoid the two drifting apart.
const BOXED_KEY_SIZE_BYTES: usize = 32 + 1 + 16;
/// Decrypts `cyphertext` with `keypair`'s secret key.
///
/// Returns `None` if no boxed key in the header opens for this keypair, or if
/// the body fails to decrypt with the recovered message key.
pub fn decrypt(cyphertext: &[u8], keypair: &Keypair) -> Option<Vec<u8>> {
    decrypt_key(cyphertext, keypair).and_then(|msg_key| decrypt_body(cyphertext, &msg_key))
}
/// Recovers the message key from `cyphertext` using `keypair`'s secret key.
///
/// Reads the 24-byte nonce and 32-byte ephemeral public key from the header,
/// derives the shared key, and tries to open each of the (at most
/// `MAX_RECIPIENTS`) boxed keys that follow until one succeeds.
///
/// Returns `None` if the ciphertext is too short to hold the header, if the
/// header fields or shared secret cannot be decoded/derived, or if no boxed
/// key opens. Truncated input yields `None` rather than panicking.
pub fn decrypt_key(cyphertext: &[u8], keypair: &Keypair) -> Option<MsgKey> {
    let nonce = Nonce::from_slice(cyphertext.get(0..24)?)?;
    let eph_pk = EphPublicKey::from_slice(cyphertext.get(24..56)?)?;
    let key_key = Key(derive_shared_secret_sk(&keypair.secret, &eph_pk)?.0);
    let boxed_keys = cyphertext.get(56..)?;

    let mut msg_key = MsgKey::zeroed();
    // `any` stops at the first boxed key that opens; on success the decrypted
    // bytes have been written into `msg_key`.
    let opened = boxed_keys
        .chunks_exact(BOXED_KEY_SIZE_BYTES)
        .take(MAX_RECIPIENTS)
        .any(|b| key_key.open_attached_into(b, &nonce, msg_key.as_bytes_mut()));
    if opened {
        Some(msg_key)
    } else {
        None
    }
}
/// Decrypts the message body of `cyphertext` with an already-recovered
/// `msg_key`.
///
/// The body starts after the 56-byte header (nonce + ephemeral public key) and
/// `msg_key.recp_count` boxed keys; it consists of the HMAC followed by the
/// encrypted text.
///
/// Returns `None` if the ciphertext is too short for the header, the boxed
/// keys, or the HMAC (e.g. truncated input or an inconsistent `recp_count`),
/// or if `open_attached_into` rejects the body.
pub fn decrypt_body(cyphertext: &[u8], msg_key: &MsgKey) -> Option<Vec<u8>> {
    let nonce = Nonce::from_slice(cyphertext.get(0..24)?)?;
    let body_start = 56 + BOXED_KEY_SIZE_BYTES * msg_key.recp_count as usize;
    let boxed_msg = cyphertext.get(body_start..)?;
    // A body shorter than the HMAC cannot be valid; avoid the usize underflow.
    let text_len = boxed_msg.len().checked_sub(Hmac::SIZE)?;
    let mut out = vec![0; text_len];
    if msg_key.key.open_attached_into(boxed_msg, &nonce, &mut out) {
        Some(out)
    } else {
        None
    }
}
/// Like `decrypt_body`, but takes the message key in its raw 33-byte encoded
/// form (as produced by `MsgKey::as_array`).
///
/// Returns `None` if `msg_key` does not have the exact size of a `MsgKey`,
/// instead of panicking, or if `decrypt_body` fails.
pub fn decrypt_body_with_key_bytes(cyphertext: &[u8], msg_key: &[u8]) -> Option<Vec<u8>> {
    let key = LayoutVerified::<&[u8], MsgKey>::new(msg_key)?.into_ref();
    decrypt_body(cyphertext, key)
}
#[cfg(test)]
mod tests {
    use crate::*;
    use base64::decode;
    use serde_derive::{Deserialize, Serialize};
    use serde_json;
    use std::error::Error;
    use std::fs::File;
    use std::path::Path;
    use ssb_crypto::Keypair;

    /// A keypair as stored in the JSON fixture (base64-encoded strings).
    #[derive(Serialize, Deserialize)]
    struct Key {
        secret: String,
        public: String,
    }

    /// Cross-implementation fixture: a base64 ciphertext, the plaintext it
    /// should decrypt to, and the recipients' keys (presumably generated by
    /// the JS implementation, given `is_js_compatible`).
    #[derive(Serialize, Deserialize)]
    struct TestData {
        cypher_text: String,
        msg: String,
        keys: Vec<Key>,
    }

    /// Loads a `TestData` fixture from a JSON file.
    fn read_test_data_from_file<P: AsRef<Path>>(path: P) -> Result<TestData, Box<dyn Error>> {
        let file = File::open(path)?;
        let t = serde_json::from_reader(file)?;
        Ok(t)
    }

    #[test]
    fn simple() {
        let msg: [u8; 3] = [0, 1, 2];
        let alice = Keypair::generate();
        let bob = Keypair::generate();
        let recps = [alice.public, bob.public];
        let cypher = encrypt(&msg, &recps);
        let alice_result = decrypt(&cypher, &alice);
        let bob_result = decrypt(&cypher, &bob);
        assert_eq!(alice_result.unwrap(), msg);
        assert_eq!(bob_result.unwrap(), msg);
    }

    /// Exactly `MAX_RECIPIENTS` recipients is allowed, and every one of them
    /// can decrypt.
    #[test]
    fn max_recipients_roundtrip() {
        let msg: [u8; 3] = [0, 1, 2];
        let keypairs: Vec<Keypair> = (0..MAX_RECIPIENTS).map(|_| Keypair::generate()).collect();
        let recps: Vec<PublicKey> = keypairs.iter().map(|kp| kp.public).collect();
        let cypher = encrypt(&msg, &recps);
        for kp in &keypairs {
            assert_eq!(decrypt(&cypher, kp).unwrap(), msg);
        }
    }

    #[test]
    fn is_js_compatible() {
        let test_data = read_test_data_from_file("./test/simple.json").unwrap();
        let cypher = decode(&test_data.cypher_text).unwrap();
        let keys: Vec<Keypair> = test_data
            .keys
            .iter()
            .map(|key| Keypair::from_base64(&key.secret).unwrap())
            .collect();
        let alice = &keys[0];
        let bob = &keys[1];
        assert_eq!(decrypt(&cypher, alice).unwrap(), test_data.msg.as_bytes());
        assert_eq!(decrypt(&cypher, bob).unwrap(), test_data.msg.as_bytes());
    }

    // `expected` ensures these fail if something other than the recipient-count
    // check is what panics.
    #[test]
    #[should_panic(expected = "Number of recipients must be")]
    fn passing_too_many_recipients_panics() {
        let msg: [u8; 3] = [0, 1, 2];
        let alice = Keypair::generate();
        let recps = vec![alice.public; MAX_RECIPIENTS + 1];
        let _ = encrypt(&msg, &recps);
    }

    #[test]
    #[should_panic(expected = "Number of recipients must be")]
    fn passing_zero_recipients_panics() {
        let msg: [u8; 3] = [0, 1, 2];
        let recps: [PublicKey; 0] = [];
        let _ = encrypt(&msg, &recps);
    }
}