use openmls_traits::crypto::OpenMlsCrypto;
use std::collections::VecDeque;
use openmls_traits::types::Ciphersuite;
use crate::ciphersuite::{AeadNonce, *};
use crate::tree::secret_tree::*;
use super::*;
/// Generation counter of a ratchet, stored as an unsigned 32-bit integer.
pub(crate) type Generation = u32;
/// Limits applied when key material is requested from a [`DecryptionRatchet`].
///
/// The default configuration uses an out-of-order tolerance of 5 and a
/// maximum forward distance of 1000.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SenderRatchetConfiguration {
/// How many generations behind the ratchet head a requested generation may
/// be; also the number of past key-material entries retained after pruning.
out_of_order_tolerance: Generation,
/// How many generations ahead of the ratchet head a requested generation
/// may be before the request is rejected.
maximum_forward_distance: Generation,
}
impl SenderRatchetConfiguration {
pub fn new(out_of_order_tolerance: Generation, maximum_forward_distance: Generation) -> Self {
Self {
out_of_order_tolerance,
maximum_forward_distance,
}
}
pub fn out_of_order_tolerance(&self) -> Generation {
self.out_of_order_tolerance
}
pub fn maximum_forward_distance(&self) -> Generation {
self.maximum_forward_distance
}
}
impl Default for SenderRatchetConfiguration {
fn default() -> Self {
Self::new(5, 1000)
}
}
/// Key material produced by one ratchet step: an AEAD key paired with an
/// AEAD nonce.
pub(crate) type RatchetKeyMaterial = (AeadKey, AeadNonce);
/// A sender ratchet, which is either a plain [`RatchetSecret`] or a
/// [`DecryptionRatchet`].
///
/// `PartialEq` and `Clone` are only derived for tests or with the
/// `test-utils` feature; `Debug` only for tests or with the `crypto-debug`
/// feature.
#[derive(Serialize, Deserialize)]
#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
#[cfg_attr(any(feature = "crypto-debug", test), derive(Debug))]
pub(crate) enum SenderRatchet {
/// Ratchet backed directly by a [`RatchetSecret`], with no store of past
/// key material.
EncryptionRatchet(RatchetSecret),
/// Ratchet backed by a [`DecryptionRatchet`], which also keeps a window of
/// past key material.
DecryptionRatchet(DecryptionRatchet),
}
impl SenderRatchet {
    /// Returns the current generation of the wrapped ratchet, whichever
    /// variant it is. Only compiled for tests.
    #[cfg(test)]
    pub(crate) fn generation(&self) -> Generation {
        match self {
            Self::EncryptionRatchet(encryption) => encryption.generation(),
            Self::DecryptionRatchet(decryption) => decryption.generation(),
        }
    }
}
/// The head of a ratchet: the current secret together with the generation it
/// belongs to.
///
/// `PartialEq` and `Clone` are only derived for tests or with the
/// `test-utils` feature.
#[derive(Debug, Serialize, Deserialize, Default)]
#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
pub(crate) struct RatchetSecret {
/// Secret from which the key material of `generation` and the secret of
/// the following generation are derived.
secret: Secret,
/// Generation that `secret` belongs to; starts at 0 and is incremented on
/// every successful ratchet step.
generation: Generation,
}
impl RatchetSecret {
    /// Creates a ratchet head at generation 0 that uses `secret` as its
    /// initial secret.
    pub(crate) fn initial_ratchet_secret(secret: Secret) -> Self {
        let generation = 0;
        Self { secret, generation }
    }

    /// Returns the generation the ratchet is currently at.
    pub(crate) fn generation(&self) -> Generation {
        self.generation
    }

    /// Derives the key material of the current generation and advances the
    /// ratchet by one generation.
    ///
    /// On success, returns the generation the key material belongs to (the
    /// generation before the step) together with the derived AEAD key and
    /// nonce. The stored secret is replaced by the secret of the following
    /// generation.
    ///
    /// # Errors
    ///
    /// * [`SecretTreeError::RatchetTooLong`] if the generation is already
    ///   `u32::MAX`.
    /// * Any error returned by `derive_tree_secret`.
    ///
    /// When an error is returned, neither the secret nor the generation has
    /// been modified.
    pub(crate) fn ratchet_forward(
        &mut self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
    ) -> Result<(Generation, RatchetKeyMaterial), SecretTreeError> {
        log::trace!("Ratcheting forward in generation {}.", self.generation);
        log_crypto!(trace, " with secret {:x?}", self.secret);

        // Refuse to step past the last representable generation.
        let current = self.generation;
        let next = current
            .checked_add(1)
            .ok_or(SecretTreeError::RatchetTooLong)?;

        // All three derivations use the same secret and generation; only the
        // label and the output length differ.
        let derive = |label, length| {
            derive_tree_secret(ciphersuite, &self.secret, label, current, length, crypto)
        };
        let nonce = derive("nonce", ciphersuite.aead_nonce_length())?;
        let key = derive("key", ciphersuite.aead_key_length())?;
        let next_secret = derive("secret", ciphersuite.hash_length())?;

        // Every fallible step has succeeded; commit the new state.
        self.secret = next_secret;
        self.generation = next;

        let key_material = (
            AeadKey::from_secret(key, ciphersuite),
            AeadNonce::from_secret(nonce),
        );
        Ok((current, key_material))
    }

    /// Overwrites the generation counter without touching the secret. Only
    /// compiled for tests.
    #[cfg(test)]
    pub(crate) fn set_generation(&mut self, generation: Generation) {
        self.generation = generation;
    }
}
/// A ratchet that hands out key material for requested generations and keeps
/// a bounded window of key material for generations it has already passed.
///
/// `PartialEq` and `Clone` are only derived for tests or with the
/// `test-utils` feature; `Debug` only for tests or with the `crypto-debug`
/// feature.
#[derive(Serialize, Deserialize)]
#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
#[cfg_attr(any(feature = "crypto-debug", test), derive(Debug))]
pub struct DecryptionRatchet {
/// Key material of past generations, newest at the front. An entry is
/// `None` when its key material has already been handed out.
past_secrets: VecDeque<Option<RatchetKeyMaterial>>,
/// The head of the ratchet, i.e. the next generation that has not been
/// derived yet.
ratchet_head: RatchetSecret,
}
impl DecryptionRatchet {
pub(crate) fn new(secret: Secret) -> Self {
Self {
past_secrets: VecDeque::new(),
ratchet_head: RatchetSecret::initial_ratchet_secret(secret),
}
}
fn prune_past_secrets(&mut self, configuration: &SenderRatchetConfiguration) {
self.past_secrets
.truncate(configuration.out_of_order_tolerance() as usize)
}
pub(crate) fn generation(&self) -> Generation {
self.ratchet_head.generation()
}
#[cfg(test)]
pub(crate) fn ratchet_secret_mut(&mut self) -> &mut RatchetSecret {
&mut self.ratchet_head
}
pub(crate) fn secret_for_decryption(
&mut self,
ciphersuite: Ciphersuite,
crypto: &impl OpenMlsCrypto,
generation: Generation,
configuration: &SenderRatchetConfiguration,
) -> Result<RatchetKeyMaterial, SecretTreeError> {
log::debug!("secret_for_decryption");
if self.generation() < u32::MAX - configuration.maximum_forward_distance()
&& generation > self.generation() + configuration.maximum_forward_distance()
{
return Err(SecretTreeError::TooDistantInTheFuture);
}
if generation < self.generation()
&& (self.generation() - generation) > configuration.out_of_order_tolerance()
{
log::error!(" Generation is too far in the past (broke out of order tolerance ({}) {generation} < {}).", configuration.out_of_order_tolerance(), self.generation());
return Err(SecretTreeError::TooDistantInThePast);
}
if generation >= self.generation() {
for _ in 0..(generation - self.generation()) {
let ratchet_secrets = {
self.ratchet_head
.ratchet_forward(crypto, ciphersuite)
.map(|(_, key_material)| key_material)
}?;
self.past_secrets.push_front(Some(ratchet_secrets));
}
let ratchet_secrets = {
self.ratchet_head
.ratchet_forward(crypto, ciphersuite)
.map(|(_, key_material)| key_material)
}?;
self.past_secrets.push_front(None);
self.prune_past_secrets(configuration);
Ok(ratchet_secrets)
} else {
let window_index = ((self.generation() - generation) as i32) - 1;
let index = if window_index >= 0 {
window_index as usize
} else {
log::error!(" Generation is too far in the past (not in the window).");
return Err(SecretTreeError::TooDistantInThePast);
};
self.past_secrets
.get_mut(index)
.ok_or(SecretTreeError::IndexOutOfBounds)?
.take()
.ok_or(SecretTreeError::SecretReuseError)
}
}
}