use rand_core::{CryptoRng, RngCore};
use std::cmp;
use std::error::Error;
use std::fmt;
use std::{collections::HashMap, hash::Hash};
/// Maximum number of message keys a single decryption may skip within one receiving chain
/// (checked separately for the current chain and for the previous chain on a DH ratchet step).
const MAX_SKIP: usize = 1000;
/// Message counter type carried in every [`Header`].
pub type Counter = u32;
/// State of one party of a Double Ratchet session, generic over the cryptographic primitives.
pub struct DoubleRatchet<CP: CryptoProvider> {
    /// Our current ratchet key pair; its public half is sent in every header.
    dhs: CP::KeyPair,
    /// The other party's current ratchet public key, once known.
    dhr: Option<CP::PublicKey>,
    /// Root key, replaced by `kdf_rk` on every DH ratchet step.
    rk: CP::RootKey,
    /// Sending chain key. `None` until one is derived; it is reset to `None` when a message on
    /// a new remote key arrives and is re-derived lazily on the next send.
    cks: Option<CP::ChainKey>,
    /// Receiving chain key for the chain identified by `dhr`, if any.
    ckr: Option<CP::ChainKey>,
    /// Index of the next message in the current sending chain.
    ns: Counter,
    /// Index of the next expected message in the current receiving chain.
    nr: Counter,
    /// Length of our previous sending chain, reported to the peer as `Header::pn`.
    pn: Counter,
    /// Message keys of skipped (not yet received) messages.
    mkskipped: KeyStore<CP>,
}
impl<CP> fmt::Debug for DoubleRatchet<CP>
where
    CP: CryptoProvider,
    CP::KeyPair: fmt::Debug,
    CP::PublicKey: fmt::Debug,
    CP::RootKey: fmt::Debug,
    CP::ChainKey: fmt::Debug,
    CP::MessageKey: fmt::Debug,
{
    /// Writes every field on a single line, in declaration order.
    ///
    /// Note that the output includes key material (root, chain and skipped message keys).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self {
            dhs,
            dhr,
            rk,
            cks,
            ckr,
            ns,
            nr,
            pn,
            mkskipped,
        } = self;
        write!(
            f,
            "DoubleRatchet {{ dhs: {dhs:?}, dhr: {dhr:?}, rk: {rk:?}, cks: {cks:?}, ckr: {ckr:?}, \
             ns: {ns:?}, nr: {nr:?}, pn: {pn:?}, mkskipped: {mkskipped:?} }}",
            dhs = dhs,
            dhr = dhr,
            rk = rk,
            cks = cks,
            ckr = ckr,
            ns = ns,
            nr = nr,
            pn = pn,
            mkskipped = mkskipped,
        )
    }
}
impl<CP: CryptoProvider> DoubleRatchet<CP> {
    /// Creates the ratchet for the initiating party ("Alice").
    ///
    /// A fresh key pair is drawn from `rng` and a DH ratchet step with `them` (the responder's
    /// public key) is performed immediately, so Alice can encrypt without receiving first.
    /// `initial_receive`, if given, becomes the receiving chain key for messages the responder
    /// sends under `them` (symmetric setup).
    pub fn new_alice<R: CryptoRng + RngCore>(
        shared_secret: &CP::RootKey,
        them: CP::PublicKey,
        initial_receive: Option<CP::ChainKey>,
        rng: &mut R,
    ) -> Self {
        let dhs = CP::KeyPair::new(rng);
        let (rk, cks) = CP::kdf_rk(shared_secret, &CP::diffie_hellman(&dhs, &them));
        Self {
            dhs,
            dhr: Some(them),
            rk,
            cks: Some(cks),
            ckr: initial_receive,
            ns: 0,
            nr: 0,
            pn: 0,
            mkskipped: KeyStore::new(),
        }
    }

    /// Creates the ratchet for the responding party ("Bob").
    ///
    /// `us` is the key pair whose public half was given to Alice. Without `initial_send`, Bob
    /// cannot encrypt until he has received a message (see [`Self::try_ratchet_encrypt`]).
    pub fn new_bob(
        shared_secret: CP::RootKey,
        us: CP::KeyPair,
        initial_send: Option<CP::ChainKey>,
    ) -> Self {
        Self {
            dhs: us,
            dhr: None,
            rk: shared_secret,
            cks: initial_send,
            ckr: None,
            ns: 0,
            nr: 0,
            pn: 0,
            mkskipped: KeyStore::new(),
        }
    }

    /// Encrypts `plaintext`, or reports that no sending chain can be derived yet.
    ///
    /// # Errors
    /// Returns [`EncryptUninit`] when there is neither a sending chain key nor a known remote
    /// public key, i.e. exactly when [`Self::ratchet_encrypt`] would panic.
    pub fn try_ratchet_encrypt<R: CryptoRng + RngCore>(
        &mut self,
        plaintext: &[u8],
        associated_data: &[u8],
        rng: &mut R,
    ) -> Result<(Header<CP::PublicKey>, Vec<u8>), EncryptUninit> {
        if self.can_encrypt() {
            Ok(self.ratchet_encrypt(plaintext, associated_data, rng))
        } else {
            Err(EncryptUninit)
        }
    }

    /// Encrypts `plaintext` and returns the header to send along with the ciphertext.
    ///
    /// `associated_data` and the encoded header are passed to [`CryptoProvider::encrypt`] as
    /// associated data.
    ///
    /// # Panics
    /// Panics with "not yet initialized for encryption" if there is neither a sending chain key
    /// nor a known remote public key; use [`Self::try_ratchet_encrypt`] to get an error instead.
    pub fn ratchet_encrypt<R: CryptoRng + RngCore>(
        &mut self,
        plaintext: &[u8],
        associated_data: &[u8],
        rng: &mut R,
    ) -> (Header<CP::PublicKey>, Vec<u8>) {
        let (h, mk) = self.ratchet_send_chain(rng);
        let ct = CP::encrypt(&mk, plaintext, &Self::concat(&h, associated_data));
        (h, ct)
    }

    /// Whether a message key can be produced: a sending chain exists, or one can be derived
    /// from the remote public key.
    fn can_encrypt(&self) -> bool {
        self.cks.is_some() || self.dhr.is_some()
    }

    /// Advances the sending chain and returns the header and message key for the next message.
    ///
    /// Without a sending chain (initially for Bob, or after receiving on a new remote key), the
    /// sending half of a DH ratchet step runs first: a fresh key pair is generated and mixed
    /// with the remote public key into the root key.
    fn ratchet_send_chain<R: CryptoRng + RngCore>(
        &mut self,
        rng: &mut R,
    ) -> (Header<CP::PublicKey>, CP::MessageKey) {
        if self.cks.is_none() {
            let dhr = self
                .dhr
                .as_ref()
                .expect("not yet initialized for encryption");
            self.dhs = CP::KeyPair::new(rng);
            let (rk, cks) = CP::kdf_rk(&self.rk, &CP::diffie_hellman(&self.dhs, dhr));
            self.rk = rk;
            self.cks = Some(cks);
            self.pn = self.ns;
            self.ns = 0;
        }
        let h = Header {
            dh: self.dhs.public().clone(),
            n: self.ns,
            pn: self.pn,
        };
        let (cks, mk) = CP::kdf_ck(
            self.cks
                .as_ref()
                .expect("sending chain key is initialized above"),
        );
        self.cks = Some(cks);
        self.ns += 1;
        (h, mk)
    }

    /// Decrypts `ciphertext` that was sent with `header`.
    ///
    /// The ratchet state is only modified if decryption succeeds, so a forged or corrupted
    /// message leaves the session untouched.
    ///
    /// # Errors
    /// - [`DecryptError::MessageKeyNotFound`] if the message key is no longer (or was never)
    ///   derivable, e.g. a replayed message or an inconsistent header.
    /// - [`DecryptError::SkipTooLarge`] if more than `MAX_SKIP` keys would have to be skipped.
    /// - [`DecryptError::StorageFull`] if the skipped keys would not fit in the key store.
    /// - Any error returned by [`CryptoProvider::decrypt`].
    pub fn ratchet_decrypt(
        &mut self,
        header: &Header<CP::PublicKey>,
        ciphertext: &[u8],
        associated_data: &[u8],
    ) -> Result<Vec<u8>, DecryptError> {
        let ad = Self::concat(header, associated_data);
        let (diff, pt) = self.try_decrypt(header, ciphertext, &ad)?;
        self.update(diff, header);
        Ok(pt)
    }

    /// Trial-decrypts without mutating `self`, returning the state change to apply on success.
    fn try_decrypt(
        &self,
        h: &Header<CP::PublicKey>,
        ct: &[u8],
        ad: &[u8],
    ) -> Result<(Diff<CP>, Vec<u8>), DecryptError> {
        use Diff::*;
        if let Some(mk) = self.mkskipped.get(&h.dh, h.n) {
            Ok((OldKey, CP::decrypt(mk, ct, ad)?))
        } else if self.dhr.as_ref() == Some(&h.dh) {
            // The header may name the current remote key while no receiving chain exists
            // (e.g. Alice created without `initial_receive`). No key can be derived for such a
            // message, and since the header is attacker-controlled this must be an error, not a
            // panic.
            let ckr = self
                .ckr
                .as_ref()
                .ok_or(DecryptError::MessageKeyNotFound)?;
            let (ckr, mut mks) = Self::skip_message_keys(ckr, self.get_current_skip(h)?);
            let mk = mks
                .pop()
                .expect("skip_message_keys always yields at least one key");
            Ok((CurrentChain(ckr, mks), CP::decrypt(&mk, ct, ad)?))
        } else {
            let (rk, ckr) = CP::kdf_rk(&self.rk, &CP::diffie_hellman(&self.dhs, &h.dh));
            let (ckr, mut mks) = Self::skip_message_keys(&ckr, self.get_next_skip(h)?);
            let mk = mks
                .pop()
                .expect("skip_message_keys always yields at least one key");
            Ok((NextChain(rk, ckr, mks), CP::decrypt(&mk, ct, ad)?))
        }
    }

    /// Number of keys to skip in the current receiving chain to reach `h.n`.
    ///
    /// Fails if `h.n` lies before `nr` (its key is not in the store), exceeds `MAX_SKIP`, or
    /// the skipped keys would not fit in the store.
    fn get_current_skip(&self, h: &Header<CP::PublicKey>) -> Result<usize, DecryptError> {
        let skip =
            h.n.checked_sub(self.nr)
                .ok_or(DecryptError::MessageKeyNotFound)? as usize;
        if MAX_SKIP < skip {
            Err(DecryptError::SkipTooLarge)
        } else if self.mkskipped.can_store(skip) {
            Ok(skip)
        } else {
            Err(DecryptError::StorageFull)
        }
    }

    /// Number of keys to skip in the new receiving chain started by `h`.
    ///
    /// Also validates the `h.pn - nr` keys still owed on the previous receiving chain. Both
    /// counts must stay within `MAX_SKIP`, and together they must fit in the store. `update`
    /// stores exactly `prev_skip` keys from the previous chain (only when one exists) plus
    /// `skip` keys from the new one, so `prev_skip + skip` is a tight upper bound.
    fn get_next_skip(&self, h: &Header<CP::PublicKey>) -> Result<usize, DecryptError> {
        let prev_skip =
            h.pn.checked_sub(self.nr)
                .ok_or(DecryptError::MessageKeyNotFound)? as usize;
        let skip = h.n as usize;
        if MAX_SKIP < cmp::max(prev_skip, skip) {
            Err(DecryptError::SkipTooLarge)
        } else if self.mkskipped.can_store(prev_skip + skip) {
            Ok(skip)
        } else {
            Err(DecryptError::StorageFull)
        }
    }

    /// Applies the state change computed by [`Self::try_decrypt`] for header `h`.
    fn update(&mut self, diff: Diff<CP>, h: &Header<CP::PublicKey>) {
        use Diff::*;
        match diff {
            OldKey => self.mkskipped.remove(&h.dh, h.n),
            CurrentChain(ckr, mks) => {
                self.mkskipped.extend(&h.dh, self.nr, mks);
                self.ckr = Some(ckr);
                self.nr = h.n + 1;
            }
            NextChain(rk, ckr, mks) => {
                // Keep the keys for messages `nr..pn` of the previous receiving chain that
                // were never received, so they can still be decrypted later.
                if let (Some(prev_ckr), Some(prev_dhr)) = (self.ckr.as_ref(), self.dhr.as_ref())
                {
                    if self.nr < h.pn {
                        let (_, prev_mks) =
                            Self::skip_message_keys(prev_ckr, (h.pn - self.nr - 1) as usize);
                        self.mkskipped.extend(prev_dhr, self.nr, prev_mks);
                    }
                }
                self.dhr = Some(h.dh.clone());
                self.rk = rk;
                // Our next send performs the sending half of the DH ratchet step.
                self.cks = None;
                self.ckr = Some(ckr);
                self.nr = h.n + 1;
                self.mkskipped.extend(&h.dh, 0, mks);
            }
        }
    }

    /// Runs `kdf_ck` `skip + 1` times starting from `ckr`.
    ///
    /// Returns the resulting chain key and all `skip + 1` message keys in chain order; the last
    /// key belongs to the target message.
    fn skip_message_keys(ckr: &CP::ChainKey, skip: usize) -> (CP::ChainKey, Vec<CP::MessageKey>) {
        let mut mks = Vec::with_capacity(skip + 1);
        let (mut ckr, mk) = CP::kdf_ck(ckr);
        mks.push(mk);
        for _ in 0..skip {
            let (next_ckr, mk) = CP::kdf_ck(&ckr);
            ckr = next_ckr;
            mks.push(mk);
        }
        (ckr, mks)
    }

    /// Builds the associated data for encryption/decryption: the caller's `ad` followed by the
    /// byte encoding of `h`.
    fn concat(h: &Header<CP::PublicKey>, ad: &[u8]) -> Vec<u8> {
        let mut v = Vec::new();
        v.extend_from_slice(ad);
        h.extend_bytes_into(&mut v);
        v
    }
}
/// Header sent in the clear alongside every ciphertext.
///
/// Its byte encoding is included in the associated data, so tampering with it makes decryption
/// fail.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Header<PublicKey> {
    /// The sender's current ratchet public key.
    pub dh: PublicKey,
    /// Index of this message within the sender's current sending chain.
    pub n: Counter,
    /// Number of messages in the sender's previous sending chain.
    pub pn: Counter,
}
impl<PK: AsRef<[u8]>> Header<PK> {
    /// Appends the byte encoding of this header to `out`: the raw public-key bytes, followed by
    /// `n` and `pn` as big-endian integers.
    fn extend_bytes_into(&self, out: &mut Vec<u8>) {
        let key_bytes: &[u8] = self.dh.as_ref();
        out.extend_from_slice(key_bytes);
        for counter in &[self.n, self.pn] {
            out.extend_from_slice(&counter.to_be_bytes());
        }
    }
}
/// The cryptographic primitives the ratchet is built on.
///
/// Implementations supply the key types and the DH, KDF and encryption functions; the ratchet
/// only sequences calls to them.
pub trait CryptoProvider {
    /// A ratchet public key. It is sent in every [`Header`], used as a map key for skipped
    /// message keys, and its bytes are part of the associated data.
    type PublicKey: AsRef<[u8]> + Clone + Eq + Hash;
    /// A ratchet key pair whose public half is [`Self::PublicKey`].
    type KeyPair: KeyPair<PublicKey = Self::PublicKey>;
    /// Output of [`Self::diffie_hellman`], consumed by [`Self::kdf_rk`].
    type SharedSecret;
    /// Root key, replaced on every DH ratchet step.
    type RootKey;
    /// Sending or receiving chain key.
    type ChainKey;
    /// Per-message key derived from a chain key by [`Self::kdf_ck`].
    type MessageKey;
    /// Combines our key pair with the other party's public key. Both parties must arrive at
    /// the same value for their chains to agree.
    fn diffie_hellman(us: &Self::KeyPair, them: &Self::PublicKey) -> Self::SharedSecret;
    /// Derives the next root key and a new chain key from the current root key and a DH output.
    fn kdf_rk(
        root_key: &Self::RootKey,
        shared_secret: &Self::SharedSecret,
    ) -> (Self::RootKey, Self::ChainKey);
    /// Advances a chain key, returning the next chain key and the message key for the current
    /// position.
    fn kdf_ck(chain_key: &Self::ChainKey) -> (Self::ChainKey, Self::MessageKey);
    /// Encrypts `plaintext` under `key`, binding `associated_data` to the result (decryption is
    /// expected to fail if it differs). Returns the ciphertext.
    fn encrypt(key: &Self::MessageKey, plaintext: &[u8], associated_data: &[u8]) -> Vec<u8>;
    /// Inverse of [`Self::encrypt`]. Presumably should return
    /// [`DecryptError::DecryptFailure`] when the ciphertext or associated data were altered
    /// (the mock provider and the tests do so).
    fn decrypt(
        key: &Self::MessageKey,
        ciphertext: &[u8],
        associated_data: &[u8],
    ) -> Result<Vec<u8>, DecryptError>;
}
/// A DH key pair as used by [`CryptoProvider`].
pub trait KeyPair {
    /// Type of the public half.
    type PublicKey;
    /// Generates a fresh key pair using `rng`.
    fn new<R: CryptoRng + RngCore>(rng: &mut R) -> Self;
    /// Returns the public half of the pair.
    fn public(&self) -> &Self::PublicKey;
}
/// Maximum total number of skipped message keys a [`KeyStore`] may hold across all chains.
const MKS_CAPACITY: usize = 2000;
/// Skipped message keys, indexed by the sender's ratchet public key and then by message counter.
struct KeyStore<CP: CryptoProvider>(HashMap<CP::PublicKey, HashMap<Counter, CP::MessageKey>>);
impl<CP> fmt::Debug for KeyStore<CP>
where
    CP: CryptoProvider,
    CP::PublicKey: fmt::Debug,
    CP::MessageKey: fmt::Debug,
{
    /// Prints the underlying map on one line as `KeyStore({...})`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self(by_public_key) = self;
        write!(f, "KeyStore({:?})", by_public_key)
    }
}
impl<CP: CryptoProvider> KeyStore<CP> {
    /// Creates an empty store.
    fn new() -> Self {
        Self(HashMap::new())
    }

    /// Returns the skipped message key stored for message `n` of the chain keyed by `dh`.
    fn get(&self, dh: &CP::PublicKey, n: Counter) -> Option<&CP::MessageKey> {
        self.0.get(dh)?.get(&n)
    }

    /// Whether `n` more keys fit without the total exceeding `MKS_CAPACITY`.
    fn can_store(&self, n: usize) -> bool {
        let current: usize = self.0.values().map(|chain| chain.len()).sum();
        current + n <= MKS_CAPACITY
    }

    /// Stores `mks` under `dh`, assigning consecutive counters starting at `n`.
    ///
    /// An empty batch is a no-op. Inserting an empty inner map would leave one dead entry per
    /// DH ratchet step in the outer map, which would then grow without bound over a long
    /// session.
    fn extend(&mut self, dh: &CP::PublicKey, n: Counter, mks: Vec<CP::MessageKey>) {
        if mks.is_empty() {
            return;
        }
        let values = (n..).zip(mks);
        match self.0.get_mut(dh) {
            Some(chain) => chain.extend(values),
            None => {
                self.0.insert(dh.clone(), values.collect());
            }
        }
    }

    /// Deletes the key for message `n` of the chain keyed by `dh`, and drops the chain entry
    /// once it holds no more keys.
    ///
    /// Callers only remove keys they just found with [`Self::get`]; the debug assertions check
    /// that. Only counter `n` is ever removed, so other keys of the chain are never lost.
    fn remove(&mut self, dh: &CP::PublicKey, n: Counter) {
        debug_assert!(self.0.contains_key(dh));
        if let Some(chain) = self.0.get_mut(dh) {
            let removed = chain.remove(&n);
            debug_assert!(removed.is_some());
            if chain.is_empty() {
                self.0.remove(dh);
            }
        }
    }
}
/// State change computed by a successful trial decryption (`try_decrypt`), applied by `update`.
enum Diff<CP: CryptoProvider> {
    /// The message key came from the skipped-key store and must now be deleted from it.
    OldKey,
    /// The message belongs to the current receiving chain: the advanced receiving chain key and
    /// the skipped message keys to store.
    CurrentChain(CP::ChainKey, Vec<CP::MessageKey>),
    /// The message starts a new receiving chain: the new root key, the new receiving chain key
    /// and the skipped message keys of the new chain to store.
    NextChain(CP::RootKey, CP::ChainKey, Vec<CP::MessageKey>),
}
/// Error returned by `DoubleRatchet::try_ratchet_encrypt` when the ratchet has neither a
/// sending chain key nor the remote party's public key.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct EncryptUninit;

impl fmt::Display for EncryptUninit {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Encrypt not yet initialized (you must receive a message first)")
    }
}

impl Error for EncryptUninit {}
/// Reasons why `DoubleRatchet::ratchet_decrypt` can fail.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DecryptError {
    /// The crypto provider rejected the ciphertext or associated data.
    DecryptFailure,
    /// The key for this message is not available (already used, or never derivable).
    MessageKeyNotFound,
    /// Reaching this message would require skipping more than `MAX_SKIP` keys.
    SkipTooLarge,
    /// The skipped keys would not fit in the key store.
    StorageFull,
}

impl Error for DecryptError {}

impl fmt::Display for DecryptError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message = match *self {
            DecryptError::DecryptFailure => "Error during verify-decrypting",
            DecryptError::MessageKeyNotFound => {
                "Could not find the message key required for decryption"
            }
            DecryptError::SkipTooLarge => "Header message counter is too large",
            DecryptError::StorageFull => "Storage for skipped messages is full",
        };
        f.write_str(message)
    }
}
#[cfg(feature = "test")]
// The mock is only consumed by tests, so some of its items may go unused.
#[allow(unused)]
#[allow(missing_docs)]
pub mod mock {
    //! Insecure, deterministic stand-ins for the cryptographic primitives, for tests only.

    pub type DoubleRatchet = super::DoubleRatchet<CryptoProvider>;

    pub struct CryptoProvider;

    impl super::CryptoProvider for CryptoProvider {
        type KeyPair = KeyPair;
        type PublicKey = PublicKey;
        type SharedSecret = u8;
        type RootKey = [u8; 2];
        type ChainKey = [u8; 3];
        type MessageKey = [u8; 3];

        // Public keys are `secret + 1` (see `KeyPair::new`), so
        // a + (b + 1) == b + (a + 1) (mod 256): both parties get the same value.
        fn diffie_hellman(us: &KeyPair, them: &PublicKey) -> u8 {
            us.0[0].wrapping_add(them.0[0])
        }

        fn kdf_rk(rk: &[u8; 2], s: &u8) -> ([u8; 2], [u8; 3]) {
            ([rk[0], *s], [rk[0], rk[1], 0])
        }

        // The message key is the current chain key; the next chain key bumps the last byte.
        fn kdf_ck(ck: &[u8; 3]) -> ([u8; 3], [u8; 3]) {
            ([ck[0], ck[1], ck[2].wrapping_add(1)], *ck)
        }

        // "Ciphertext" is `mk || pt || ad`, so decrypt can check both the key and the AD.
        fn encrypt(mk: &[u8; 3], pt: &[u8], ad: &[u8]) -> Vec<u8> {
            let mut ct = Vec::from(&mk[..]);
            ct.extend_from_slice(pt);
            ct.extend_from_slice(ad);
            ct
        }

        fn decrypt(mk: &[u8; 3], ct: &[u8], ad: &[u8]) -> Result<Vec<u8>, super::DecryptError> {
            if ct.len() < 3 + ad.len() || ct[..3] != mk[..] || !ct.ends_with(ad) {
                Err(super::DecryptError::DecryptFailure)
            } else {
                Ok(Vec::from(&ct[3..ct.len() - ad.len()]))
            }
        }
    }

    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    pub struct PublicKey([u8; 1]);

    impl AsRef<[u8]> for PublicKey {
        fn as_ref(&self) -> &[u8] {
            &self.0
        }
    }

    #[derive(Debug)]
    pub struct KeyPair([u8; 1], PublicKey);

    impl super::KeyPair for KeyPair {
        type PublicKey = PublicKey;

        // Only the low byte of the RNG output is used as the mock secret.
        #[allow(clippy::cast_possible_truncation)]
        fn new<R: rand_core::CryptoRng + rand_core::RngCore>(rng: &mut R) -> Self {
            let n = rng.next_u32() as u8;
            // `wrapping_add` so a secret of 255 does not overflow (which panics in debug
            // builds); the DH symmetry above holds modulo 256.
            Self([n], PublicKey([n.wrapping_add(1)]))
        }

        fn public(&self) -> &PublicKey {
            &self.1
        }
    }

    /// Counter-based "RNG": returns 1, 2, 3, ... so tests are deterministic.
    #[derive(Default)]
    pub struct Rng(u64);

    impl rand_core::RngCore for Rng {
        fn next_u64(&mut self) -> u64 {
            self.0 += 1;
            self.0
        }

        // Deliberately keeps only the low 32 bits of the counter.
        #[allow(clippy::cast_possible_truncation)]
        fn next_u32(&mut self) -> u32 {
            self.next_u64() as u32
        }

        fn fill_bytes(&mut self, out: &mut [u8]) {
            rand_core::impls::fill_bytes_via_next(self, out);
        }

        fn try_fill_bytes(&mut self, out: &mut [u8]) -> Result<(), rand_core::Error> {
            self.fill_bytes(out);
            Ok(())
        }
    }

    impl super::CryptoRng for Rng {}
}
#[cfg(test)]
mod tests {
    use super::*;

    type DR = DoubleRatchet<mock::CryptoProvider>;

    /// Alice knows Bob's public key; Bob cannot send until he has received from Alice.
    fn asymmetric_setup(rng: &mut mock::Rng) -> (DR, DR) {
        let secret = [42, 0];
        let pair = mock::KeyPair::new(rng);
        let pubkey = pair.public().clone();
        let alice = DR::new_alice(&secret, pubkey, None, rng);
        let bob = DR::new_bob(secret, pair, None);
        (alice, bob)
    }

    /// Both parties additionally share an initial chain key, so both can send immediately.
    fn symmetric_setup(rng: &mut mock::Rng) -> (DR, DR) {
        let secret = [42, 0];
        let ck_init = [42, 0, 0];
        let pair = mock::KeyPair::new(rng);
        let pubkey = pair.public().clone();
        let alice = DR::new_alice(&secret, pubkey, Some(ck_init), rng);
        let bob = DR::new_bob(secret, pair, Some(ck_init));
        (alice, bob)
    }

    #[test]
    fn test_asymmetric_setup() {
        let mut rng = mock::Rng::default();
        let (mut alice, mut bob) = asymmetric_setup(&mut rng);
        let (pt_a, ad_a) = (b"Hi Bobby", b"A2B");
        let (pt_b, ad_b) = (b"What's up Al?", b"B2A");
        let (h_a, ct_a) = alice.ratchet_encrypt(pt_a, ad_a, &mut rng);
        assert_eq!(
            Err(EncryptUninit),
            bob.try_ratchet_encrypt(pt_b, ad_b, &mut rng)
        );
        assert_eq!(
            Ok(Vec::from(&pt_a[..])),
            bob.ratchet_decrypt(&h_a, &ct_a, ad_a)
        );
        let (h_b, ct_b) = bob.ratchet_encrypt(pt_b, ad_b, &mut rng);
        assert_eq!(
            Ok(Vec::from(&pt_b[..])),
            alice.ratchet_decrypt(&h_b, &ct_b, ad_b)
        );
    }

    #[test]
    fn test_symmetric_setup() {
        let mut rng = mock::Rng::default();
        let (mut alice, mut bob) = symmetric_setup(&mut rng);
        let (pt_a, ad_a) = (b"Hi Bobby", b"A2B");
        let (pt_b, ad_b) = (b"What's up Al?", b"B2A");
        let (h_a, ct_a) = alice.ratchet_encrypt(pt_a, ad_a, &mut rng);
        let (h_b, ct_b) = bob.ratchet_encrypt(pt_b, ad_b, &mut rng);
        assert_eq!(
            Ok(Vec::from(&pt_a[..])),
            bob.ratchet_decrypt(&h_a, &ct_a, ad_a)
        );
        assert_eq!(
            Ok(Vec::from(&pt_b[..])),
            alice.ratchet_decrypt(&h_b, &ct_b, ad_b)
        );
    }

    /// Out-of-order delivery within a single (symmetric-key) chain.
    #[test]
    fn symmetric_out_of_order() {
        let mut rng = mock::Rng::default();
        let (mut alice, mut bob) = asymmetric_setup(&mut rng);
        let (ad_a, ad_b) = (b"A2B", b"B2A");
        let pt_a_0 = b"Hi Bobby";
        let (h_a_0, ct_a_0) = alice.ratchet_encrypt(pt_a_0, ad_a, &mut rng);
        for _ in 1..9 {
            alice.ratchet_encrypt(b"hello?", ad_a, &mut rng);
        }
        let pt_a_9 = b"are you there?";
        let (h_a_9, ct_a_9) = alice.ratchet_encrypt(pt_a_9, ad_a, &mut rng);
        assert_eq!(
            Ok(Vec::from(&pt_a_9[..])),
            bob.ratchet_decrypt(&h_a_9, &ct_a_9, ad_a)
        );
        assert_eq!(
            Ok(Vec::from(&pt_a_0[..])),
            bob.ratchet_decrypt(&h_a_0, &ct_a_0, ad_a)
        );
        let pt_b_0 = b"Yes I'm here";
        let (h_b_0, ct_b_0) = bob.ratchet_encrypt(pt_b_0, ad_b, &mut rng);
        for _ in 1..9 {
            bob.ratchet_encrypt(b"why?", ad_b, &mut rng);
        }
        let pt_b_9 = b"Tell me why!!!";
        let (h_b_9, ct_b_9) = bob.ratchet_encrypt(pt_b_9, ad_b, &mut rng);
        assert_eq!(
            Ok(Vec::from(&pt_b_9[..])),
            alice.ratchet_decrypt(&h_b_9, &ct_b_9, ad_b)
        );
        assert_eq!(
            Ok(Vec::from(&pt_b_0[..])),
            alice.ratchet_decrypt(&h_b_0, &ct_b_0, ad_b)
        );
    }

    /// Out-of-order delivery across DH ratchet steps.
    #[test]
    fn dh_out_of_order() {
        let mut rng = mock::Rng::default();
        let (mut alice, mut bob) = asymmetric_setup(&mut rng);
        let (ad_a, ad_b) = (b"A2B", b"B2A");
        let pt_a_0 = b"Good day Robert";
        let (h_a_0, ct_a_0) = alice.ratchet_encrypt(pt_a_0, ad_a, &mut rng);
        assert_eq!(
            Ok(Vec::from(&pt_a_0[..])),
            bob.ratchet_decrypt(&h_a_0, &ct_a_0, ad_a)
        );
        let pt_a_1 = b"Do you like Rust?";
        let (h_a_1, ct_a_1) = alice.ratchet_encrypt(pt_a_1, ad_a, &mut rng);
        let pt_b_0 = b"Salutations Allison";
        let (h_b_0, ct_b_0) = bob.ratchet_encrypt(pt_b_0, ad_b, &mut rng);
        let pt_b_1 = b"How is your day going?";
        let (h_b_1, ct_b_1) = bob.ratchet_encrypt(pt_b_1, ad_b, &mut rng);
        assert_eq!(
            Ok(Vec::from(&pt_b_1[..])),
            alice.ratchet_decrypt(&h_b_1, &ct_b_1, ad_b)
        );
        let pt_a_2 = b"My day is fine.";
        let (h_a_2, ct_a_2) = alice.ratchet_encrypt(pt_a_2, ad_a, &mut rng);
        assert_eq!(
            Ok(Vec::from(&pt_a_2[..])),
            bob.ratchet_decrypt(&h_a_2, &ct_a_2, ad_a)
        );
        assert_eq!(
            Ok(Vec::from(&pt_a_1[..])),
            bob.ratchet_decrypt(&h_a_1, &ct_a_1, ad_a)
        );
        let pt_b_2 = b"Yes I like Rust";
        let (h_b_2, ct_b_2) = bob.ratchet_encrypt(pt_b_2, ad_b, &mut rng);
        assert_eq!(
            Ok(Vec::from(&pt_b_2[..])),
            alice.ratchet_decrypt(&h_b_2, &ct_b_2, ad_b)
        );
        assert_eq!(
            Ok(Vec::from(&pt_b_0[..])),
            alice.ratchet_decrypt(&h_b_0, &ct_b_0, ad_b)
        );
    }

    #[test]
    #[should_panic(expected = "not yet initialized for encryption")]
    fn encrypt_error() {
        let mut rng = mock::Rng::default();
        let (_alice, mut bob) = asymmetric_setup(&mut rng);
        assert_eq!(
            Err(EncryptUninit),
            bob.try_ratchet_encrypt(b"", b"", &mut rng)
        );
        bob.ratchet_encrypt(b"", b"", &mut rng);
    }

    /// Tampered ciphertexts, headers or associated data are rejected without corrupting state,
    /// both on a new chain, on the current chain and for stored skipped keys.
    #[test]
    fn decrypt_failure() {
        let mut rng = mock::Rng::default();
        let (mut alice, mut bob) = asymmetric_setup(&mut rng);
        let (ad_a, ad_b) = (b"A2B", b"B2A");
        let (h_a_0, ct_a_0) = alice.ratchet_encrypt(b"Hi Bob", ad_a, &mut rng);
        let mut ct_a_0_err = ct_a_0.clone();
        ct_a_0_err[2] ^= 0x80;
        let mut h_a_0_err = h_a_0.clone();
        h_a_0_err.pn = 1;
        assert_eq!(
            Err(DecryptError::DecryptFailure),
            bob.ratchet_decrypt(&h_a_0, &ct_a_0_err, ad_a)
        );
        assert_eq!(
            Err(DecryptError::DecryptFailure),
            bob.ratchet_decrypt(&h_a_0_err, &ct_a_0, ad_a)
        );
        assert_eq!(
            Err(DecryptError::DecryptFailure),
            bob.ratchet_decrypt(&h_a_0, &ct_a_0, ad_b)
        );
        let (h_a_1, ct_a_1) = alice.ratchet_encrypt(b"Hi Bob", ad_a, &mut rng);
        bob.ratchet_decrypt(&h_a_1, &ct_a_1, ad_a).unwrap();
        let (h_a_2, ct_a_2) = alice.ratchet_encrypt(b"Hi Bob", ad_a, &mut rng);
        let mut h_a_2_err = h_a_2.clone();
        h_a_2_err.pn += 1;
        let mut ct_a_2_err = ct_a_2.clone();
        ct_a_2_err[0] ^= 0x04;
        assert_eq!(
            Err(DecryptError::DecryptFailure),
            bob.ratchet_decrypt(&h_a_2, &ct_a_2_err, ad_a)
        );
        assert_eq!(
            Err(DecryptError::DecryptFailure),
            bob.ratchet_decrypt(&h_a_2_err, &ct_a_2, ad_a)
        );
        assert_eq!(
            Err(DecryptError::DecryptFailure),
            bob.ratchet_decrypt(&h_a_2, &ct_a_2, ad_b)
        );
        let (h_b, ct_b) = bob.ratchet_encrypt(b"Hi Alice", ad_b, &mut rng);
        alice.ratchet_decrypt(&h_b, &ct_b, ad_b).unwrap();
        let (h_a_3, ct_a_3) = alice.ratchet_encrypt(b"Hi Bob", ad_a, &mut rng);
        bob.ratchet_decrypt(&h_a_3, &ct_a_3, ad_a).unwrap();
        assert_eq!(
            Err(DecryptError::DecryptFailure),
            bob.ratchet_decrypt(&h_a_2, &ct_a_2_err, ad_a)
        );
        assert_eq!(
            Err(DecryptError::DecryptFailure),
            bob.ratchet_decrypt(&h_a_2_err, &ct_a_2, ad_a)
        );
        assert_eq!(
            Err(DecryptError::DecryptFailure),
            bob.ratchet_decrypt(&h_a_2, &ct_a_2, ad_b)
        );
    }

    /// A message can be decrypted at most once.
    #[test]
    fn double_sending() {
        let mut rng = mock::Rng::default();
        let (mut alice, mut bob) = asymmetric_setup(&mut rng);
        let (ad_a, ad_b) = (b"A2B", b"B2A");
        let (h_a_0, ct_a_0) = alice.ratchet_encrypt(b"Whatever", ad_a, &mut rng);
        bob.ratchet_decrypt(&h_a_0, &ct_a_0, ad_a).unwrap();
        assert!(bob.ratchet_decrypt(&h_a_0, &ct_a_0, ad_a).is_err());
        let (h_b_0, ct_b_0) = bob.ratchet_encrypt(b"Whatever", ad_b, &mut rng);
        alice.ratchet_decrypt(&h_b_0, &ct_b_0, ad_b).unwrap();
        assert!(alice.ratchet_decrypt(&h_b_0, &ct_b_0, ad_b).is_err());
        let (h_a_1, ct_a_1) = alice.ratchet_encrypt(b"Whatever", ad_a, &mut rng);
        bob.ratchet_decrypt(&h_a_1, &ct_a_1, ad_a).unwrap();
        assert!(bob.ratchet_decrypt(&h_a_1, &ct_a_1, ad_a).is_err());
        let (h_b_1, ct_b_1) = bob.ratchet_encrypt(b"Whatever", ad_b, &mut rng);
        alice.ratchet_decrypt(&h_b_1, &ct_b_1, ad_b).unwrap();
        assert!(alice.ratchet_decrypt(&h_b_1, &ct_b_1, ad_b).is_err());
        assert!(bob.ratchet_decrypt(&h_a_0, &ct_a_0, ad_a).is_err());
        assert!(alice.ratchet_decrypt(&h_b_0, &ct_b_0, ad_b).is_err());
    }

    /// A header whose `pn` lies about the previous chain length is rejected.
    #[test]
    fn invalid_header() {
        let mut rng = mock::Rng::default();
        let (mut alice, mut bob) = asymmetric_setup(&mut rng);
        let (ad_a, ad_b) = (b"A2B", b"B2A");
        let (h_a_0, ct_a_0) = alice.ratchet_encrypt(b"Hi Bob", ad_a, &mut rng);
        bob.ratchet_decrypt(&h_a_0, &ct_a_0, ad_a).unwrap();
        let (h_b_0, ct_b_0) = bob.ratchet_encrypt(b"Hi Alice", ad_b, &mut rng);
        alice.ratchet_decrypt(&h_b_0, &ct_b_0, ad_b).unwrap();
        let (mut h_a_1, ct_a_1) = alice.ratchet_encrypt(b"I will lie to you now", ad_a, &mut rng);
        assert_eq!(h_a_1.pn, 1);
        h_a_1.pn = 0;
        assert!(bob.ratchet_decrypt(&h_a_1, &ct_a_1, ad_a).is_err());
    }

    /// Both `n` and `pn` beyond `MAX_SKIP` are rejected.
    #[test]
    fn skip_too_large() {
        let mut rng = mock::Rng::default();
        let (mut alice, mut bob) = asymmetric_setup(&mut rng);
        let (ad_a, ad_b) = (b"A2B", b"B2A");
        let (h_a_0, ct_a_0) = alice.ratchet_encrypt(b"Hi Bob", ad_a, &mut rng);
        for _ in 0..=MAX_SKIP {
            alice.ratchet_encrypt(b"Not sending this", ad_a, &mut rng);
        }
        let (h_a_1, ct_a_1) = alice.ratchet_encrypt(b"n > MAXSKIP", ad_a, &mut rng);
        assert_eq!(
            Err(DecryptError::SkipTooLarge),
            bob.ratchet_decrypt(&h_a_1, &ct_a_1, ad_a)
        );
        bob.ratchet_decrypt(&h_a_0, &ct_a_0, ad_a).unwrap();
        let (h_b, ct_b) = bob.ratchet_encrypt(b"Hi Alice", ad_b, &mut rng);
        alice.ratchet_decrypt(&h_b, &ct_b, ad_b).unwrap();
        let (h_a_2, ct_a_2) = alice.ratchet_encrypt(b"pn > MAXSKIP", ad_a, &mut rng);
        assert_eq!(
            Err(DecryptError::SkipTooLarge),
            bob.ratchet_decrypt(&h_a_2, &ct_a_2, ad_a)
        );
    }

    /// Once `MKS_CAPACITY` skipped keys are stored, skipping further messages is refused.
    #[test]
    fn storage_full() {
        let mut rng = mock::Rng::default();
        let (mut alice, mut bob) = asymmetric_setup(&mut rng);
        let ad_a = b"A2B";
        let mut stored = 0;
        while stored < MKS_CAPACITY {
            // Each skipped message forces Bob to store one key; a single decryption may skip
            // at most MAX_SKIP of them.
            let batch = cmp::min(MAX_SKIP, MKS_CAPACITY - stored);
            for _ in 0..batch {
                alice.ratchet_encrypt(b"Not sending this", ad_a, &mut rng);
            }
            let (h_a, ct_a) = alice.ratchet_encrypt(b"Hello Bob", ad_a, &mut rng);
            bob.ratchet_decrypt(&h_a, &ct_a, ad_a).unwrap();
            stored += batch;
        }
        alice.ratchet_encrypt(b"Bob can't store this key anymore", ad_a, &mut rng);
        let (h_a, ct_a) = alice.ratchet_encrypt(b"Gotcha, Bob!", ad_a, &mut rng);
        assert_eq!(
            Err(DecryptError::StorageFull),
            bob.ratchet_decrypt(&h_a, &ct_a, ad_a)
        );
    }

    /// A peer with inconsistent counters must not be able to make the other side panic.
    #[test]
    fn cannot_crash_other() {
        let mut rng = mock::Rng::default();
        let (ad_a, ad_b) = (b"A2B", b"B2A");
        let (mut alice, mut bob) = symmetric_setup(&mut rng);
        alice.pn = 10;
        bob.pn = 10;
        let (h_a, ct_a) = alice.ratchet_encrypt(b"not important", ad_a, &mut rng);
        let (h_b, ct_b) = bob.ratchet_encrypt(b"not important", ad_b, &mut rng);
        let _ = alice.ratchet_decrypt(&h_b, &ct_b, ad_b);
        let _ = bob.ratchet_decrypt(&h_a, &ct_a, ad_a);
        let (mut alice, mut bob) = asymmetric_setup(&mut rng);
        alice.pn = 10;
        let (h_a, ct_a) = alice.ratchet_encrypt(b"not important", ad_a, &mut rng);
        let _ = bob.ratchet_decrypt(&h_a, &ct_a, ad_a);
        bob.pn = 10;
        let (h_b, ct_b) = bob.ratchet_encrypt(b"not important", ad_b, &mut rng);
        let _ = alice.ratchet_decrypt(&h_b, &ct_b, ad_b);
    }
}