#[cfg(test)]
mod secret_tree_test;
use bytes::{BufMut, Bytes, BytesMut};
use std::fmt::{Display, Formatter};
use crate::crypto::{cipher_suite::*, provider::CryptoProvider};
use crate::framing::*;
use crate::utilities::error::*;
use crate::utilities::tree_math::*;
/// Label string for the handshake ratchet ([`RatchetLabel::Handshake`]).
const RATCHET_LABEL_HANDSHAKE_STR: &str = "handshake";
/// Label string for the application ratchet ([`RatchetLabel::Application`]).
const RATCHET_LABEL_APPLICATION_STR: &str = "application";

/// Selects which per-leaf ratchet a secret belongs to.
///
/// The [`Display`] form of each variant is its label string
/// (`"handshake"` / `"application"`).
#[derive(Default, Debug, Copy, Clone, Eq, PartialEq)]
pub enum RatchetLabel {
    /// Ratchet used for handshake content (proposals and commits). This is the default.
    #[default]
    Handshake,
    /// Ratchet used for application content.
    Application,
}

impl RatchetLabel {
    /// Returns the label string for this ratchet without allocating.
    pub const fn as_str(self) -> &'static str {
        match self {
            RatchetLabel::Handshake => RATCHET_LABEL_HANDSHAKE_STR,
            RatchetLabel::Application => RATCHET_LABEL_APPLICATION_STR,
        }
    }
}

impl Display for RatchetLabel {
    /// Writes the label string.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!(f, "{}", ..)`) honours width, alignment and
        // precision flags requested by the caller, matching `str`'s own
        // `Display`. With default flags the output is just the label.
        f.pad(self.as_str())
    }
}
impl From<ContentType> for RatchetLabel {
    /// Maps a framed content type to the ratchet that protects it.
    ///
    /// Application messages use the application ratchet; proposals and
    /// commits both use the handshake ratchet.
    fn from(content_type: ContentType) -> Self {
        match content_type {
            ContentType::Commit => Self::Handshake,
            ContentType::Proposal => Self::Handshake,
            ContentType::Application => Self::Application,
        }
    }
}
#[derive(Default, Debug, Clone, Eq, PartialEq)]
pub struct SecretTree(pub(crate) Vec<Option<Bytes>>);
impl SecretTree {
pub fn new(
crypto_provider: &impl CryptoProvider,
cipher_suite: CipherSuite,
n: NumLeaves,
encryption_secret: &[u8],
) -> Result<Self> {
let mut tree = SecretTree(vec![None; n.width() as usize]);
tree.set(n.root(), encryption_secret.to_vec().into());
tree.derive_children(crypto_provider, cipher_suite, n.root())?;
Ok(tree)
}
fn derive_children(
&mut self,
crypto_provider: &impl CryptoProvider,
cipher_suite: CipherSuite,
x: NodeIndex,
) -> Result<()> {
let (l, r, ok) = x.children();
if !ok {
return Ok(());
}
let parent_secret = self
.get(x)
.ok_or(Error::InvalidParentNode)?
.as_ref()
.ok_or(Error::InvalidParentNode)?;
let nh = crypto_provider.hpke(cipher_suite).kdf_extract_size() as u16;
let left_secret =
crypto_provider.expand_with_label(cipher_suite, parent_secret, b"tree", b"left", nh)?;
let right_secret = crypto_provider.expand_with_label(
cipher_suite,
parent_secret,
b"tree",
b"right",
nh,
)?;
self.set(l, left_secret);
self.set(r, right_secret);
self.derive_children(crypto_provider, cipher_suite, l)?;
self.derive_children(crypto_provider, cipher_suite, r)?;
Ok(())
}
fn get(&self, ni: NodeIndex) -> Option<&Option<Bytes>> {
self.0.get(ni.0 as usize)
}
fn set(&mut self, ni: NodeIndex, secret: Bytes) {
if (ni.0 as usize) < self.0.len() {
self.0[ni.0 as usize] = Some(secret);
}
}
pub fn derive_ratchet_root(
&self,
crypto_provider: &impl CryptoProvider,
cipher_suite: CipherSuite,
ni: NodeIndex,
label: RatchetLabel,
) -> Result<RatchetSecret> {
let parent_secret = self
.get(ni)
.ok_or(Error::InvalidParentNode)?
.as_ref()
.ok_or(Error::InvalidParentNode)?;
let nh = crypto_provider.hpke(cipher_suite).kdf_extract_size() as u16;
let secret = crypto_provider.expand_with_label(
cipher_suite,
parent_secret,
label.to_string().as_bytes(),
&[],
nh,
)?;
Ok(RatchetSecret {
secret,
generation: 0,
})
}
}
/// One step of a key ratchet: the current secret and its generation counter.
#[derive(Default, Debug, Clone, Eq, PartialEq)]
pub struct RatchetSecret {
    /// Secret for the current generation.
    pub secret: Bytes,
    /// Generation number of `secret`; [`RatchetSecret::derive_next`] increments it.
    pub generation: u32,
}

impl RatchetSecret {
    /// Derives the nonce for the current generation.
    ///
    /// The output length is the AEAD nonce size of `cipher_suite`'s HPKE
    /// suite, derived with label `"nonce"`.
    pub fn derive_nonce(
        &self,
        crypto_provider: &impl CryptoProvider,
        cipher_suite: CipherSuite,
    ) -> Result<Bytes> {
        let nonce_len = crypto_provider.hpke(cipher_suite).aead_nonce_size() as u16;
        self.expand(crypto_provider, cipher_suite, b"nonce", nonce_len)
    }

    /// Derives the key for the current generation.
    ///
    /// The output length is the AEAD key size of `cipher_suite`'s HPKE suite,
    /// derived with label `"key"`.
    pub fn derive_key(
        &self,
        crypto_provider: &impl CryptoProvider,
        cipher_suite: CipherSuite,
    ) -> Result<Bytes> {
        let key_len = crypto_provider.hpke(cipher_suite).aead_key_size() as u16;
        self.expand(crypto_provider, cipher_suite, b"key", key_len)
    }

    /// Advances the ratchet by one step.
    ///
    /// The new secret is derived with label `"secret"` at the KDF extract
    /// size, and the generation is incremented by one.
    pub fn derive_next(
        &self,
        crypto_provider: &impl CryptoProvider,
        cipher_suite: CipherSuite,
    ) -> Result<RatchetSecret> {
        let secret_len = crypto_provider.hpke(cipher_suite).kdf_extract_size() as u16;
        let next_secret = self.expand(crypto_provider, cipher_suite, b"secret", secret_len)?;
        // NOTE(review): `+ 1` panics on overflow in debug builds and wraps in
        // release when `generation == u32::MAX`; consider a checked add that
        // returns an error.
        Ok(Self {
            secret: next_secret,
            generation: self.generation + 1,
        })
    }

    /// Calls `derive_tree_secret` on this secret and generation with the given
    /// `label` and output `length`.
    fn expand(
        &self,
        crypto_provider: &impl CryptoProvider,
        cipher_suite: CipherSuite,
        label: &[u8],
        length: u16,
    ) -> Result<Bytes> {
        derive_tree_secret(
            crypto_provider,
            cipher_suite,
            &self.secret,
            label,
            self.generation,
            length,
        )
    }
}
/// Expands `secret` under `label` to `length` bytes, using `generation` as context.
///
/// The context passed to `expand_with_label` is the 4-byte big-endian
/// encoding of `generation`.
///
/// # Errors
///
/// Propagates any error returned by `expand_with_label`.
pub fn derive_tree_secret(
    crypto_provider: &impl CryptoProvider,
    cipher_suite: CipherSuite,
    secret: &[u8],
    label: &[u8],
    generation: u32,
    length: u16,
) -> Result<Bytes> {
    // `put_u32` writes in big-endian order; the buffer is sized exactly for it.
    let mut generation_be = BytesMut::with_capacity(std::mem::size_of::<u32>());
    generation_be.put_u32(generation);
    let context: Bytes = generation_be.freeze();
    crypto_provider.expand_with_label(cipher_suite, secret, label, &context, length)
}