use super::leaf_node::LeafNode;
use crate::client::MlsError;
use crate::crypto::HpkePublicKey;
use crate::tree_kem::math as tree_math;
use crate::tree_kem::parent_hash::ParentHash;
use alloc::vec;
use alloc::vec::Vec;
use core::hash::Hash;
use core::ops::{Deref, DerefMut};
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use tree_math::{CopathNode, TreeIndex};
#[cfg(feature = "serde")]
use mls_rs_core::error::IntoAnyError;
/// Largest value accepted for a [`LeafIndex`]: `2^24 - 1`. `LeafIndex::try_from`,
/// MLS decoding and serde deserialization all reject anything larger.
pub(crate) const MAX_LEAF_INDEX: u32 = 0x00FF_FFFF;
/// A parent (non-leaf) node of the tree, stored at odd positions of a `NodeVec`.
// NOTE(review): field order is presumably the order used by the derived MLS
// encoding — do not reorder fields without confirming.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Parent {
    /// Public key held by this node.
    pub public_key: HpkePublicKey,
    /// Parent hash stored with this node (`ParentHash::empty()` for parents created by
    /// `NodeVec::borrow_or_fill_node_as_parent`).
    pub parent_hash: ParentHash,
    /// Leaves listed right after this node in its resolution
    /// (see `NodeVec::get_resolution_index`).
    pub unmerged_leaves: Vec<LeafIndex>,
}
/// Index of a leaf position; leaf `i` occupies node index `2 * i`
/// (see the `From<LeafIndex> for NodeIndex` conversion).
///
/// `TryFrom<u32>`, MLS decoding and (with the `serde` feature) deserialization reject
/// values above [`MAX_LEAF_INDEX`]; the crate-internal `unchecked` constructors do not.
/// `MlsDecode` and `serde::Deserialize` are implemented by hand so they can apply that
/// range check, which is why they are missing from the derive lists below.
#[derive(Clone, Copy, Debug, Ord, PartialEq, PartialOrd, Hash, Eq, MlsSize, MlsEncode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct LeafIndex(u32);
/// Fuzzing support: only generates indices that `LeafIndex::try_from` would accept.
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for LeafIndex {
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        u.int_in_range(0..=MAX_LEAF_INDEX).map(LeafIndex)
    }
}
/// Checked construction: values above [`MAX_LEAF_INDEX`] are rejected with
/// `MlsError::InvalidTreeIndex`.
impl TryFrom<u32> for LeafIndex {
    type Error = MlsError;

    fn try_from(value: u32) -> Result<Self, Self::Error> {
        if value <= MAX_LEAF_INDEX {
            Ok(Self(value))
        } else {
            Err(MlsError::InvalidTreeIndex)
        }
    }
}
impl LeafIndex {
    /// Leaf index for node index `index` (`index / 2`). Does not check that `index` is
    /// a leaf position or that the result is within [`MAX_LEAF_INDEX`].
    pub(crate) fn from_node_index_unchecked(index: NodeIndex) -> Self {
        Self(index / 2)
    }

    /// Wraps `value` without checking it against [`MAX_LEAF_INDEX`].
    pub(crate) fn unchecked(value: u32) -> Self {
        LeafIndex(value)
    }

    /// The next leaf index (`self + 1`), without a range check.
    pub(crate) fn next_unchecked(&self) -> Self {
        Self::unchecked(self.0 + 1)
    }
}
impl Deref for LeafIndex {
type Target = u32;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<&LeafIndex> for NodeIndex {
fn from(leaf_index: &LeafIndex) -> Self {
leaf_index.0 * 2
}
}
impl From<LeafIndex> for NodeIndex {
fn from(leaf_index: LeafIndex) -> Self {
leaf_index.0 * 2
}
}
impl MlsDecode for LeafIndex {
fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
let val = u32::mls_decode(reader)?;
LeafIndex::try_from(val).map_err(|_| mls_rs_codec::Error::Custom(6))
}
}
/// Deserializes a `u32` and rejects values above [`MAX_LEAF_INDEX`], mirroring
/// `LeafIndex::try_from`.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for LeafIndex {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let raw = u32::deserialize(deserializer)?;

        match LeafIndex::try_from(raw) {
            Ok(index) => Ok(index),
            Err(e) => Err(serde::de::Error::custom(e.into_any_error())),
        }
    }
}
/// Position of a slot in the flat tree array (`NodeVec`): even positions are leaves,
/// odd positions are parents.
pub(crate) type NodeIndex = u32;
/// A non-blank tree node. Blank nodes are represented as `None` in an `Option<Node>`
/// slot of a `NodeVec`.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
// Silences clippy's variant-size-difference lint; both variants are stored unboxed.
#[allow(clippy::large_enum_variant)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// NOTE(review): the explicit `u8` discriminants (1 = leaf, 2 = parent) presumably act as
// the variant tag in the derived MLS encoding — confirm before renumbering.
#[repr(u8)]
pub(crate) enum Node {
    Leaf(LeafNode) = 1u8,
    Parent(Parent) = 2u8,
}
impl Node {
pub fn public_key(&self) -> &HpkePublicKey {
match self {
Node::Parent(p) => &p.public_key,
Node::Leaf(l) => &l.public_key,
}
}
}
/// Wraps a parent as an occupied tree slot.
impl From<Parent> for Option<Node> {
    fn from(parent: Parent) -> Self {
        Some(Node::Parent(parent))
    }
}
/// Wraps a leaf as an occupied tree slot.
impl From<LeafNode> for Option<Node> {
    fn from(leaf: LeafNode) -> Self {
        Some(Node::Leaf(leaf))
    }
}
impl From<Parent> for Node {
fn from(p: Parent) -> Self {
Node::Parent(p)
}
}
impl From<LeafNode> for Node {
fn from(l: LeafNode) -> Self {
Node::Leaf(l)
}
}
/// Typed views of a tree slot.
///
/// The `Option<Node>` implementation in this module behaves as follows:
/// - the `as_parent*` and `as_leaf*` methods return `MlsError::ExpectedNode` when the slot
///   is blank or holds the other kind of node;
/// - `as_non_empty` returns `MlsError::UnexpectedEmptyNode` when the slot is blank.
pub(crate) trait NodeTypeResolver {
    /// Borrows the slot's parent node.
    fn as_parent(&self) -> Result<&Parent, MlsError>;
    /// Mutably borrows the slot's parent node.
    fn as_parent_mut(&mut self) -> Result<&mut Parent, MlsError>;
    /// Borrows the slot's leaf node.
    fn as_leaf(&self) -> Result<&LeafNode, MlsError>;
    /// Mutably borrows the slot's leaf node.
    fn as_leaf_mut(&mut self) -> Result<&mut LeafNode, MlsError>;
    /// Borrows the slot's node, whichever kind it is.
    fn as_non_empty(&self) -> Result<&Node, MlsError>;
}
/// Typed access to an `Option<Node>` slot.
///
/// A blank slot or a slot holding the wrong kind of node yields `MlsError::ExpectedNode`.
/// `as_non_empty` reports a blank slot as `MlsError::UnexpectedEmptyNode`.
impl NodeTypeResolver for Option<Node> {
    fn as_parent(&self) -> Result<&Parent, MlsError> {
        match self {
            Some(Node::Parent(parent)) => Ok(parent),
            Some(Node::Leaf(_)) | None => Err(MlsError::ExpectedNode),
        }
    }

    fn as_parent_mut(&mut self) -> Result<&mut Parent, MlsError> {
        match self {
            Some(Node::Parent(parent)) => Ok(parent),
            Some(Node::Leaf(_)) | None => Err(MlsError::ExpectedNode),
        }
    }

    fn as_leaf(&self) -> Result<&LeafNode, MlsError> {
        match self {
            Some(Node::Leaf(leaf)) => Ok(leaf),
            Some(Node::Parent(_)) | None => Err(MlsError::ExpectedNode),
        }
    }

    fn as_leaf_mut(&mut self) -> Result<&mut LeafNode, MlsError> {
        match self {
            Some(Node::Leaf(leaf)) => Ok(leaf),
            Some(Node::Parent(_)) | None => Err(MlsError::ExpectedNode),
        }
    }

    fn as_non_empty(&self) -> Result<&Node, MlsError> {
        match self {
            Some(node) => Ok(node),
            None => Err(MlsError::UnexpectedEmptyNode),
        }
    }
}
/// A tree stored as a flat vector of slots indexed by [`NodeIndex`]: leaves sit at even
/// positions and parents at odd ones, and `None` is a blank node.
/// `NodeVec::borrow_node` also reports validated positions past the end of the vector
/// as blank.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct NodeVec(Vec<Option<Node>>);
impl From<Vec<Option<Node>>> for NodeVec {
fn from(x: Vec<Option<Node>>) -> Self {
NodeVec(x)
}
}
impl Deref for NodeVec {
type Target = Vec<Option<Node>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for NodeVec {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl NodeVec {
    /// Number of leaf positions that currently hold a leaf node.
    #[cfg(any(test, all(feature = "custom_proposal", feature = "tree_index")))]
    pub fn occupied_leaf_count(&self) -> u32 {
        self.non_empty_leaves().count() as u32
    }

    /// Leaf count of the tree this vector describes: `len / 2 + 1` rounded up to the next
    /// power of two (`1` for an empty vector).
    pub fn total_leaf_count(&self) -> u32 {
        (self.len() as u32 / 2 + 1).next_power_of_two()
    }

    /// Borrows the slot at `index`.
    ///
    /// Positions accepted by `validate_index` that lie past the end of the vector are
    /// reported as blank (`&None`).
    ///
    /// # Errors
    ///
    /// `MlsError::InvalidNodeIndex` when `index` is rejected by `validate_index`.
    #[inline]
    pub fn borrow_node(&self, index: NodeIndex) -> Result<&Option<Node>, MlsError> {
        Ok(self.get(self.validate_index(index)?).unwrap_or(&None))
    }

    /// Accepts `index` when it is below `self.len().next_power_of_two()` and returns it as a
    /// `usize`; otherwise fails with `MlsError::InvalidNodeIndex(index)`.
    fn validate_index(&self, index: NodeIndex) -> Result<usize, MlsError> {
        if (index as usize) >= self.len().next_power_of_two() {
            Err(MlsError::InvalidNodeIndex(index))
        } else {
            Ok(index as usize)
        }
    }

    /// Test helper: every blank leaf slot stored in the vector, paired with its leaf index.
    #[cfg(test)]
    fn empty_leaves(&mut self) -> impl Iterator<Item = (LeafIndex, &mut Option<Node>)> {
        self.iter_mut()
            .step_by(2)
            .enumerate()
            .filter(|(_, n)| n.is_none())
            .map(|(i, n)| (LeafIndex::unchecked(i as u32), n))
    }

    /// `(leaf index, leaf)` for every leaf position that holds a leaf node.
    pub fn non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_ {
        self.leaves()
            .enumerate()
            .filter_map(|(i, l)| l.map(|l| (LeafIndex::unchecked(i as u32), l)))
    }

    /// `(node index, parent)` for every odd position that holds a parent node.
    pub fn non_empty_parents(&self) -> impl Iterator<Item = (NodeIndex, &Parent)> + '_ {
        self.iter()
            .enumerate()
            .skip(1)
            .step_by(2)
            .map(|(i, n)| (i as NodeIndex, n))
            .filter_map(|(i, n)| n.as_parent().ok().map(|p| (i, p)))
    }

    /// Yields one item per even (leaf) position stored in the vector.
    ///
    /// Each item is the leaf, or `None` when the slot is blank (or does not hold a leaf).
    pub fn leaves(&self) -> impl Iterator<Item = Option<&LeafNode>> + '_ {
        self.iter().step_by(2).map(|n| n.as_leaf().ok())
    }

    /// The `(path, copath)` node pairs of leaf `index`, computed for a tree of
    /// `total_leaf_count()` leaves.
    pub fn direct_copath(&self, index: LeafIndex) -> Vec<CopathNode<NodeIndex>> {
        NodeIndex::from(index).direct_copath(&self.total_leaf_count())
    }

    /// Reports, for each entry of `direct_copath(index)`, whether the copath node's
    /// resolution is empty.
    ///
    /// `true` marks a direct-path node that is filtered out. Never returns `Err`.
    pub fn filtered(&self, index: LeafIndex) -> Result<Vec<bool>, MlsError> {
        Ok(self
            .direct_copath(index)
            .into_iter()
            .map(|cp| self.is_resolution_empty(cp.copath))
            .collect())
    }

    /// Whether the slot at `index` is blank (validated positions past the end count as
    /// blank).
    ///
    /// # Errors
    ///
    /// Same as [`Self::borrow_node`].
    #[inline]
    pub fn is_blank(&self, index: NodeIndex) -> Result<bool, MlsError> {
        self.borrow_node(index).map(|n| n.is_none())
    }

    /// Whether `index` is a leaf position (an even node index).
    #[inline]
    pub fn is_leaf(&self, index: NodeIndex) -> bool {
        index % 2 == 0
    }

    /// Blanks the slot of `leaf_index` and returns the leaf it held.
    ///
    /// # Errors
    ///
    /// - `MlsError::InvalidNodeIndex` when the position fails validation.
    /// - `MlsError::RemovingNonExistingMember` when the slot is past the end of the vector
    ///   or does not hold a leaf.
    pub fn blank_leaf_node(&mut self, leaf_index: LeafIndex) -> Result<LeafNode, MlsError> {
        let node_index = self.validate_index(leaf_index.into())?;

        match self.get_mut(node_index).and_then(Option::take) {
            Some(Node::Leaf(l)) => Ok(l),
            _ => Err(MlsError::RemovingNonExistingMember),
        }
    }

    /// Blanks every stored node on the direct path of `leaf` (the `path` side of
    /// `direct_copath`).
    ///
    /// Positions past the end of the vector are skipped. Never returns `Err`.
    pub fn blank_direct_path(&mut self, leaf: LeafIndex) -> Result<(), MlsError> {
        for i in self.direct_copath(leaf) {
            if let Some(n) = self.get_mut(i.path as usize) {
                *n = None
            }
        }

        Ok(())
    }

    /// Pops trailing blank slots so the vector is empty or ends in a non-blank node.
    pub fn trim(&mut self) {
        while self.last() == Some(&None) {
            self.pop();
        }
    }

    /// Borrows the parent at `node_index`.
    ///
    /// # Errors
    ///
    /// - `MlsError::InvalidNodeIndex` when the position fails validation.
    /// - `MlsError::ExpectedNode` when the slot is blank (including past the end) or holds
    ///   a leaf.
    pub fn borrow_as_parent(&self, node_index: NodeIndex) -> Result<&Parent, MlsError> {
        self.borrow_node(node_index).and_then(|n| n.as_parent())
    }

    /// Mutably borrows the parent at `node_index`.
    ///
    /// # Errors
    ///
    /// - `MlsError::InvalidNodeIndex` when the position fails validation or lies past the
    ///   end of the vector.
    /// - `MlsError::ExpectedNode` when the slot is blank or holds a leaf.
    pub fn borrow_as_parent_mut(&mut self, node_index: NodeIndex) -> Result<&mut Parent, MlsError> {
        let index = self.validate_index(node_index)?;

        self.get_mut(index)
            .ok_or(MlsError::InvalidNodeIndex(node_index))?
            .as_parent_mut()
    }

    /// Mutably borrows the leaf at `index`.
    ///
    /// # Errors
    ///
    /// - `MlsError::InvalidNodeIndex` when the position fails validation or lies past the
    ///   end of the vector.
    /// - `MlsError::ExpectedNode` when the slot is blank or holds a parent.
    pub fn borrow_as_leaf_mut(&mut self, index: LeafIndex) -> Result<&mut LeafNode, MlsError> {
        let node_index = NodeIndex::from(index);
        let index = self.validate_index(node_index)?;

        self.get_mut(index)
            .ok_or(MlsError::InvalidNodeIndex(node_index))?
            .as_leaf_mut()
    }

    /// Borrows the leaf at `index`.
    ///
    /// # Errors
    ///
    /// - `MlsError::InvalidNodeIndex` when the position fails validation.
    /// - `MlsError::ExpectedNode` when the slot is blank (including past the end) or holds
    ///   a parent.
    pub fn borrow_as_leaf(&self, index: LeafIndex) -> Result<&LeafNode, MlsError> {
        let node_index = NodeIndex::from(index);
        self.borrow_node(node_index).and_then(|n| n.as_leaf())
    }

    /// Borrows the parent at `node_index`, creating it when the slot is blank.
    ///
    /// The vector is first padded with blank slots up to `node_index`. A blank slot is then
    /// filled with a parent holding a clone of `public_key`, an empty parent hash and no
    /// unmerged leaves. An existing parent is returned unchanged; its key is not replaced.
    ///
    /// # Errors
    ///
    /// - `MlsError::InvalidNodeIndex` when the position fails validation.
    /// - `MlsError::ExpectedNode` when the slot holds a leaf.
    pub fn borrow_or_fill_node_as_parent(
        &mut self,
        node_index: NodeIndex,
        public_key: &HpkePublicKey,
    ) -> Result<&mut Parent, MlsError> {
        let index = self.validate_index(node_index)?;

        while self.len() <= index {
            self.push(None);
        }

        self.get_mut(index)
            .ok_or(MlsError::InvalidNodeIndex(node_index))
            .and_then(|n| {
                if n.is_none() {
                    *n = Parent {
                        public_key: public_key.clone(),
                        parent_hash: ParentHash::empty(),
                        unmerged_leaves: vec![],
                    }
                    .into();
                }
                n.as_parent_mut()
            })
    }

    /// Resolution of `index`, as node indexes:
    /// - a non-blank node contributes itself followed by its unmerged leaves;
    /// - a blank parent position contributes its left child's resolution, then its right
    ///   child's;
    /// - a blank leaf position contributes nothing.
    ///
    /// Positions past the end of the vector count as blank. Never returns `Err`.
    pub fn get_resolution_index(&self, index: NodeIndex) -> Result<Vec<NodeIndex>, MlsError> {
        let mut indexes = vec![index];
        let mut resolution = vec![];

        while let Some(index) = indexes.pop() {
            if let Some(Some(node)) = self.get(index as usize) {
                resolution.push(index);

                if let Node::Parent(p) = node {
                    resolution.extend(p.unmerged_leaves.iter().map(NodeIndex::from));
                }
            } else if !index.is_leaf() {
                // Push right before left so the left subtree is popped (emitted) first.
                indexes.push(index.right_unchecked());
                indexes.push(index.left_unchecked());
            }
        }

        Ok(resolution)
    }

    /// Position of `to_find` within the resolution of `index`, without building the list.
    ///
    /// Positions use the same ordering as [`Self::get_resolution_index`]. With `to_find`
    /// set to `None`, this returns `Some(0)` as soon as the resolution has any entry.
    /// Returns `None` when the target is absent or the resolution is empty.
    ///
    /// Unlike `get_resolution_index`, an unmerged leaf whose slot is blank is not counted.
    pub fn find_in_resolution(
        &self,
        index: NodeIndex,
        to_find: Option<NodeIndex>,
    ) -> Option<usize> {
        let mut indexes = vec![index];
        let mut resolution_len = 0;

        while let Some(index) = indexes.pop() {
            if let Some(Some(node)) = self.get(index as usize) {
                if Some(index) == to_find || to_find.is_none() {
                    return Some(resolution_len);
                }

                resolution_len += 1;

                if let Node::Parent(p) = node {
                    // `indexes` is a LIFO stack: push the unmerged leaves in reverse so they
                    // are popped (and counted) in list order, matching the positions
                    // `get_resolution_index` produces for the same node.
                    indexes.extend(p.unmerged_leaves.iter().rev().map(NodeIndex::from));
                }
            } else if !index.is_leaf() {
                // Push right before left so the left subtree is visited first.
                indexes.push(index.right_unchecked());
                indexes.push(index.left_unchecked());
            }
        }

        None
    }

    /// Whether the resolution of `index` has no entries.
    pub fn is_resolution_empty(&self, index: NodeIndex) -> bool {
        self.find_in_resolution(index, None).is_none()
    }

    /// Finds the first blank leaf position at or after `start` inside the vector.
    ///
    /// If there is none, returns the leaf position just past the end (`(len + 1) / 2`).
    /// That fallback can be smaller than `start` when `start` already lies past the end.
    pub(crate) fn next_empty_leaf(&self, start: LeafIndex) -> LeafIndex {
        let mut n = NodeIndex::from(start) as usize;

        while n < self.len() {
            if self.0[n].is_none() {
                return LeafIndex::from_node_index_unchecked(n as NodeIndex);
            }

            n += 2;
        }

        LeafIndex::from_node_index_unchecked(self.len() as NodeIndex + 1)
    }

    /// Stores `leaf` in the slot of `index`, overwriting anything already there.
    ///
    /// When the slot lies past the end of the vector, the vector is first padded with blank
    /// slots up to and including it.
    pub fn insert_leaf(&mut self, index: LeafIndex, leaf: LeafNode) {
        let node_index = (*index as usize) << 1;

        // Growing to exactly `node_index + 1` slots keeps the length odd (`node_index` is
        // even). Unlike pushing a fixed number of blanks, it never leaves `node_index` out
        // of bounds.
        if self.len() <= node_index {
            self.resize(node_index + 1, None);
        }

        self.0[node_index] = Some(leaf.into());
    }
}
#[cfg(test)]
pub(crate) mod test_utils {
    use super::*;
    use crate::{
        client::test_utils::TEST_CIPHER_SUITE, tree_kem::leaf_node::test_utils::get_basic_test_node,
    };

    /// Builds a 7-slot (4-leaf) fixture:
    /// - leaves "A" at slot 0, "C" at slot 4 and "D" at slot 6;
    /// - a parent with key `b"CD"` at slot 5 whose unmerged leaves are `[2]`;
    /// - blank slots 1, 2 and 3.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn get_test_node_vec() -> NodeVec {
        let leaf_a = get_basic_test_node(TEST_CIPHER_SUITE, "A").await;
        let leaf_c = get_basic_test_node(TEST_CIPHER_SUITE, "C").await;

        let parent_cd = Parent {
            public_key: b"CD".to_vec().into(),
            parent_hash: ParentHash::empty(),
            unmerged_leaves: vec![LeafIndex::unchecked(2)],
        };

        let leaf_d = get_basic_test_node(TEST_CIPHER_SUITE, "D").await;

        let slots: Vec<Option<Node>> = vec![
            Some(Node::Leaf(leaf_a)),
            None,
            None,
            None,
            Some(Node::Leaf(leaf_c)),
            Some(Node::Parent(parent_cd)),
            Some(Node::Leaf(leaf_d)),
        ];

        NodeVec::from(slots)
    }
}
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;
    use crate::{
        client::test_utils::TEST_CIPHER_SUITE,
        tree_kem::{
            leaf_node::test_utils::get_basic_test_node, node::test_utils::get_test_node_vec,
        },
    };

    // Fixture from `get_test_node_vec` (7 slots, 4 leaves):
    //   0: leaf "A"   1: blank   2: blank leaf   3: blank
    //   4: leaf "C"   5: parent "CD" (unmerged leaf 2)   6: leaf "D"

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn node_key_getters() {
        let test_node_parent: Node = Parent {
            public_key: b"pub".to_vec().into(),
            parent_hash: ParentHash::empty(),
            unmerged_leaves: vec![],
        }
        .into();

        let test_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "B").await;
        let test_node_leaf: Node = test_leaf.clone().into();

        assert_eq!(test_node_parent.public_key().as_ref(), b"pub");
        assert_eq!(test_node_leaf.public_key(), &test_leaf.public_key);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_empty_leaves() {
        let mut test_vec = get_test_node_vec().await;
        let mut test_vec_clone = get_test_node_vec().await;
        let empty_leaves: Vec<(LeafIndex, &mut Option<Node>)> = test_vec.empty_leaves().collect();

        assert_eq!(
            [(LeafIndex::unchecked(1), &mut test_vec_clone[2])].as_ref(),
            empty_leaves.as_slice()
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_direct_path() {
        let test_vec = get_test_node_vec().await;
        let expected = 0.direct_copath(&4);
        let actual = test_vec.direct_copath(LeafIndex::unchecked(0));
        assert_eq!(actual, expected);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_filtered_direct_path_co_path() {
        let test_vec = get_test_node_vec().await;
        let expected = [true, false];
        let actual = test_vec.filtered(LeafIndex::unchecked(0)).unwrap();
        assert_eq!(actual, expected);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_get_parent_node() {
        let mut test_vec = get_test_node_vec().await;

        // A leaf position is not a parent.
        assert!(test_vec.borrow_as_parent_mut(0).is_err());

        // Past the end of the vector.
        assert!(test_vec
            .borrow_as_parent_mut(test_vec.len() as u32)
            .is_err());

        let mut expected = Parent {
            public_key: b"CD".to_vec().into(),
            parent_hash: ParentHash::empty(),
            unmerged_leaves: vec![LeafIndex::unchecked(2)],
        };

        assert_eq!(test_vec.borrow_as_parent_mut(5).unwrap(), &mut expected);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_get_resolution() {
        let test_vec = get_test_node_vec().await;

        let resolution_node_5 = test_vec.get_resolution_index(5).unwrap();
        let resolution_node_2 = test_vec.get_resolution_index(2).unwrap();
        let resolution_node_3 = test_vec.get_resolution_index(3).unwrap();

        assert_eq!(&resolution_node_5, &[5, 4]);
        assert!(resolution_node_2.is_empty());
        assert_eq!(&resolution_node_3, &[0, 5, 4]);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_find_in_resolution() {
        let test_vec = get_test_node_vec().await;

        // Positions agree with `get_resolution_index(3) == [0, 5, 4]`.
        assert_eq!(test_vec.find_in_resolution(3, Some(0)), Some(0));
        assert_eq!(test_vec.find_in_resolution(3, Some(5)), Some(1));
        assert_eq!(test_vec.find_in_resolution(3, Some(4)), Some(2));

        // Node 6 sits under the non-blank node 5, so it is not part of the resolution.
        assert_eq!(test_vec.find_in_resolution(3, Some(6)), None);
        assert_eq!(test_vec.find_in_resolution(3, None), Some(0));

        assert!(test_vec.is_resolution_empty(2));
        assert!(!test_vec.is_resolution_empty(3));
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_get_or_fill_existing() {
        let mut test_vec = get_test_node_vec().await;
        let mut test_vec2 = test_vec.clone();

        let expected = test_vec[5].as_parent_mut().unwrap();
        let actual = test_vec2
            .borrow_or_fill_node_as_parent(5, &Vec::new().into())
            .unwrap();

        assert_eq!(actual, expected);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_get_or_fill_empty() {
        let mut test_vec = get_test_node_vec().await;

        let mut expected = Parent {
            public_key: vec![0u8; 4].into(),
            parent_hash: ParentHash::empty(),
            unmerged_leaves: vec![],
        };

        let actual = test_vec
            .borrow_or_fill_node_as_parent(1, &vec![0u8; 4].into())
            .unwrap();

        assert_eq!(actual, &mut expected);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_leaf_count() {
        let test_vec = get_test_node_vec().await;
        assert_eq!(test_vec.len(), 7);
        assert_eq!(test_vec.occupied_leaf_count(), 3);
        assert_eq!(
            test_vec.non_empty_leaves().count(),
            test_vec.occupied_leaf_count() as usize
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_total_leaf_count() {
        let test_vec = get_test_node_vec().await;
        assert_eq!(test_vec.occupied_leaf_count(), 3);
        assert_eq!(test_vec.total_leaf_count(), 4);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn max_leaf_index() {
        // Round-trip the boundary value as well as small indices, and check that the
        // decoded value equals the original.
        for value in [0, 1, MAX_LEAF_INDEX] {
            let test_index = LeafIndex::try_from(value).unwrap();
            let serialized = test_index.mls_encode_to_vec().unwrap();
            assert_eq!(LeafIndex::mls_decode(&mut &*serialized).unwrap(), test_index);

            #[cfg(feature = "serde")]
            {
                let serialized = serde_json::to_string(&test_index).unwrap();
                assert_eq!(
                    serde_json::from_str::<LeafIndex>(&serialized).unwrap(),
                    test_index
                );
            }
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn max_leaf_index_failure() {
        let res = LeafIndex::try_from(MAX_LEAF_INDEX + 1);
        assert_matches!(res, Err(MlsError::InvalidTreeIndex));

        let serialized = LeafIndex::unchecked(MAX_LEAF_INDEX + 1)
            .mls_encode_to_vec()
            .unwrap();

        let res = LeafIndex::mls_decode(&mut &*serialized);
        assert_matches!(res, Err(mls_rs_codec::Error::Custom(6)));

        #[cfg(feature = "serde")]
        {
            let serialized =
                serde_json::to_string(&LeafIndex::unchecked(MAX_LEAF_INDEX + 1)).unwrap();
            let res: Result<LeafIndex, _> = serde_json::from_str(&serialized);
            assert!(res.is_err())
        }
    }
}