use std::collections::HashMap;
use crate::{
CipherSuite,
error::{Result, SframeError},
header::KeyId,
key::{DecryptionKey, KeyStore},
};
use super::{ratcheting_base_key::RatchetingBaseKey, ratcheting_key_id::RatchetingKeyId};
/// Key store that keeps, per key ID, a decryption key together with the base key
/// used to ratchet it forward (see [`RatchetingKeys`]).
pub struct RatchetingKeyStore {
/// Stored key pairs, indexed by the [`RatchetingKeyId`] built from the caller's key ID.
keys: HashMap<RatchetingKeyId, RatchetingKeys>,
/// Bit width handed to [`RatchetingKeyId::from_key_id`] on every lookup; `try_ratchet`
/// also uses `1 << n_ratchet_bits` as the modulus for ratchet step differences.
n_ratchet_bits: u8,
}
impl RatchetingKeyStore {
    /// Creates an empty key store.
    ///
    /// `n_ratchet_bits` is forwarded to [`RatchetingKeyId::from_key_id`] whenever a key ID is
    /// converted, and defines the modulus (`1 << n_ratchet_bits`) used by
    /// [`Self::try_ratchet`] when computing step differences.
    pub fn new(n_ratchet_bits: u8) -> Self {
        Self {
            n_ratchet_bits,
            keys: Default::default(),
        }
    }

    /// Derives and stores the keys belonging to `key_id` from `key_material`.
    ///
    /// Two keys are stored: the decryption key derived directly from `key_material` and a
    /// base key obtained through [`RatchetingBaseKey::ratchet_forward`]. An entry already
    /// stored under the same ID is replaced.
    ///
    /// # Errors
    ///
    /// Propagates any error from deriving either key. Nothing is stored in that case,
    /// because both derivations run before the map is modified.
    pub fn insert<K, M>(
        &mut self,
        cipher_suite: CipherSuite,
        key_id: K,
        key_material: M,
    ) -> Result<()>
    where
        K: Into<KeyId>,
        M: AsRef<[u8]>,
    {
        let key_id = RatchetingKeyId::from_key_id(key_id.into(), self.n_ratchet_bits);
        let dec_key = DecryptionKey::derive_from(cipher_suite, key_id, &key_material)?;
        let base_key = RatchetingBaseKey::ratchet_forward(key_id, key_material, cipher_suite)?;

        self.keys.insert(key_id, RatchetingKeys { base_key, dec_key });
        Ok(())
    }

    /// Removes the keys stored for `key_id`.
    ///
    /// Returns `true` if an entry existed and was removed, `false` otherwise.
    pub fn remove<K>(&mut self, key_id: K) -> bool
    where
        K: Into<KeyId>,
    {
        let key_id = RatchetingKeyId::from_key_id(key_id.into(), self.n_ratchet_bits);
        self.keys.remove(&key_id).is_some()
    }

    /// Returns the stored key pair for `key_id`, or `None` if nothing is stored for it.
    pub fn get<K>(&self, key_id: K) -> Option<&RatchetingKeys>
    where
        K: Into<KeyId>,
    {
        let key_id = RatchetingKeyId::from_key_id(key_id.into(), self.n_ratchet_bits);
        self.keys.get(&key_id)
    }

    /// Ratchets the keys stored for `key_id` forward to the ratchet step encoded in `key_id`.
    ///
    /// The requested step is incremented once and compared (modulo `1 << n_ratchet_bits`)
    /// with the step of the stored base key. `next_base_key` is then called once per
    /// missing step, and the decryption key is re-derived from the last pair it returned.
    /// If no step is missing the stored keys are left untouched.
    ///
    /// Returns the number of ratchet steps that were performed.
    ///
    /// # Errors
    ///
    /// Returns [`SframeError::MissingDecryptionKey`] if no keys are stored for `key_id`, and
    /// propagates the first error raised while ratcheting the base key or deriving the new
    /// decryption key.
    pub fn try_ratchet<K>(&mut self, key_id: K) -> Result<u64>
    where
        K: Into<KeyId>,
    {
        let mut key_id = RatchetingKeyId::from_key_id(key_id.into(), self.n_ratchet_bits);

        let keys = self
            .keys
            .get_mut(&key_id)
            .ok_or_else(|| SframeError::MissingDecryptionKey(key_id.into()))?;

        // The stored base key is ahead of the decryption key (`insert` already ratchets it
        // forward), so the requested step is advanced before the two are compared.
        key_id.inc_ratchet_step();

        let current_ratchet_step = keys.base_key.key_id().ratchet_step();
        let max_ratchet_value = 1 << self.n_ratchet_bits;
        // Wrapping subtraction plus the modulo keeps the difference correct when the
        // requested step has wrapped around past the stored one.
        let step_diff =
            key_id.ratchet_step().wrapping_sub(current_ratchet_step) % max_ratchet_value;

        // `next_base_key` advances the stored base key, so it has to run once for every
        // missing step. Taking only the last element of a lazy `map` (e.g. via `next_back`)
        // would invoke it a single time and derive the key for the wrong step.
        let mut latest = None;
        for _ in 0..step_diff {
            latest = Some(keys.base_key.next_base_key()?);
        }

        if let Some((next_key_id, next_material)) = latest {
            keys.dec_key = DecryptionKey::derive_from(
                keys.dec_key.cipher_suite(),
                next_key_id,
                next_material,
            )?;
        }

        Ok(step_diff)
    }
}
/// The pair of keys a [`RatchetingKeyStore`] holds for a single key ID.
pub struct RatchetingKeys {
/// Base key; `try_ratchet` calls `next_base_key` on it to obtain the key ID and
/// material for the next decryption key.
pub base_key: RatchetingBaseKey,
/// Decryption key currently returned by `KeyStore::get_key` for this entry.
pub dec_key: DecryptionKey,
}
impl KeyStore for RatchetingKeyStore {
    /// Looks up the decryption key stored for `key_id`.
    ///
    /// The given ID is converted with [`RatchetingKeyId::from_key_id`] using this store's
    /// `n_ratchet_bits`. `None` is returned when nothing is stored under the resulting ID.
    fn get_key<K>(&self, key_id: K) -> Option<&DecryptionKey>
    where
        K: Into<KeyId>,
    {
        let lookup_id = RatchetingKeyId::from_key_id(key_id, self.n_ratchet_bits);
        let entry = self.keys.get(&lookup_id)?;
        Some(&entry.dec_key)
    }
}
#[cfg(test)]
mod test {
    use super::RatchetingKeyStore;
    use crate::{
        CipherSuite, header::KeyId, key::KeyStore, ratchet::ratcheting_key_id::RatchetingKeyId,
    };
    use pretty_assertions::assert_eq;

    // Shared fixtures: ratchet bit width, raw key material and key generation.
    const N_RATCHET_BITS: u8 = 8;
    const KEY_MATERIAL: &[u8] = b"SECRET";
    const GENERATION: u64 = 42;

    /// Inserting derives the decryption key for the given ID and stores a base key that
    /// has already been ratcheted forward once.
    #[test]
    fn expands_and_ratchets_forward_on_insert() {
        let key_id = RatchetingKeyId::new(GENERATION, N_RATCHET_BITS);
        let mut store = RatchetingKeyStore::new(N_RATCHET_BITS);

        store
            .insert(CipherSuite::AesGcm256Sha512, key_id, KEY_MATERIAL)
            .unwrap();

        let stored = store
            .get(key_id)
            .expect("keys should be stored after insert");
        assert_eq!(stored.base_key.key_id().generation(), GENERATION);
        assert_eq!(stored.base_key.key_id().ratchet_step(), 1);

        let unratcheted_key_id = RatchetingKeyId::new(GENERATION, N_RATCHET_BITS);
        assert_eq!(KeyId::from(unratcheted_key_id), stored.dec_key.key_id());
    }

    /// `get` on an empty store yields nothing.
    #[test]
    fn returns_none_for_unknown_key_on_get() {
        let store = RatchetingKeyStore::new(N_RATCHET_BITS);
        let unknown_id = RatchetingKeyId::new(GENERATION, N_RATCHET_BITS);

        assert!(store.get(unknown_id).is_none());
    }

    /// `remove` reports success and the entry is gone afterwards.
    #[test]
    fn removes_key() {
        let key_id = RatchetingKeyId::new(GENERATION, N_RATCHET_BITS);
        let mut store = RatchetingKeyStore::new(N_RATCHET_BITS);
        store
            .insert(CipherSuite::AesGcm128Sha256, key_id, KEY_MATERIAL)
            .unwrap();

        let removed = store.remove(key_id);
        let remaining = store.get(key_id);

        assert!(removed);
        assert!(remaining.is_none());
    }

    /// Ratcheting a key that was never inserted is an error.
    #[test]
    fn returns_err_for_unknown_key_on_ratcheting_get() {
        let mut store = RatchetingKeyStore::new(N_RATCHET_BITS);
        let unknown_id = RatchetingKeyId::new(GENERATION, N_RATCHET_BITS);

        let outcome = store.try_ratchet(unknown_id);

        assert!(outcome.is_err());
    }

    /// The `KeyStore` lookup returns the decryption key carrying the inserted ID.
    #[test]
    fn inserts_and_gets_key() {
        let key_id = RatchetingKeyId::new(GENERATION, N_RATCHET_BITS);
        let mut store = RatchetingKeyStore::new(N_RATCHET_BITS);
        store
            .insert(CipherSuite::AesGcm256Sha512, key_id, KEY_MATERIAL)
            .unwrap();

        let stored_key = store.get_key(key_id).unwrap();

        assert_eq!(KeyId::from(key_id), stored_key.key_id());
    }

    /// No ratcheting happens for the step that is already stored; asking for the next
    /// step ratchets exactly once and yields a different secret.
    #[test]
    fn inserts_key_and_ratches_forward_if_needed() {
        let mut key_id = RatchetingKeyId::new(42u8, N_RATCHET_BITS);
        let mut store = RatchetingKeyStore::new(N_RATCHET_BITS);
        store
            .insert(CipherSuite::AesGcm256Sha512, key_id, KEY_MATERIAL)
            .unwrap();

        // Step 0 is the freshly inserted key, so nothing has to be ratcheted.
        let performed_steps = store.try_ratchet(key_id).unwrap();
        assert_eq!(performed_steps, 0);

        let initial_key = store.get_key(key_id).unwrap().clone();
        let initial_key_id = RatchetingKeyId::from_key_id(initial_key.key_id(), N_RATCHET_BITS);
        assert_eq!(initial_key_id, key_id);
        assert_eq!(initial_key_id.ratchet_step(), 0);

        // Moving the requested ID one step ahead triggers a single ratchet.
        key_id.inc_ratchet_step();
        let performed_steps = store.try_ratchet(key_id).unwrap();
        assert_eq!(performed_steps, 1);

        let ratcheted_key = store.get_key(key_id).unwrap();
        assert_ne!(initial_key.secret(), ratcheted_key.secret());

        let ratcheted_key_id =
            RatchetingKeyId::from_key_id(ratcheted_key.key_id(), N_RATCHET_BITS);
        assert_eq!(ratcheted_key_id.ratchet_step(), 1);
        assert_eq!(initial_key_id, ratcheted_key_id);
    }

    /// Ratcheting to the same step a second time leaves the stored key unchanged.
    #[test]
    fn stores_ratcheted_key() {
        let mut key_id = RatchetingKeyId::new(42u8, N_RATCHET_BITS);
        let mut store = RatchetingKeyStore::new(N_RATCHET_BITS);
        store
            .insert(CipherSuite::AesGcm128Sha256, key_id, KEY_MATERIAL)
            .unwrap();
        key_id.inc_ratchet_step();

        store.try_ratchet(key_id).unwrap();
        let key_after_first_call = store.get_key(key_id).unwrap().clone();

        store.try_ratchet(key_id).unwrap();
        let key_after_second_call = store.get_key(key_id).unwrap().clone();

        assert_eq!(key_after_first_call, key_after_second_call);
    }

    /// With a single ratchet bit the second increment is expected to overflow the step
    /// counter; the store must still ratchet to a new key.
    #[test]
    fn ratchets_on_ratcheting_step_overflow() {
        let n_ratchet_bits = 1;
        let mut key_id = RatchetingKeyId::new(42u8, n_ratchet_bits);
        let mut store = RatchetingKeyStore::new(n_ratchet_bits);
        store
            .insert(CipherSuite::AesGcm256Sha512, key_id, KEY_MATERIAL)
            .unwrap();

        key_id.inc_ratchet_step();
        store.try_ratchet(key_id).unwrap();
        let key_before_overflow = store.get_key(key_id).unwrap().clone();

        key_id.inc_ratchet_step();
        store.try_ratchet(key_id).unwrap();
        let key_after_overflow = store.get_key(key_id).unwrap().clone();

        assert_ne!(key_before_overflow, key_after_overflow);
    }
}