use std::collections::HashSet;
use openmls_traits::crypto::OpenMlsCrypto;
use openmls_traits::types::Ciphersuite;
use serde::{Deserialize, Serialize};
use tls_codec::VLByteSlice;
use crate::{
binary_tree::array_representation::{tree::TreeNode, LeafNodeIndex},
error::LibraryError,
};
use super::{hashes::TreeHashInput, LeafNode, Node, ParentNode};
/// A node of the treesync tree: either a boxed leaf or a boxed parent.
///
/// Both payload types wrap an `Option`, so either variant may be blank.
pub(crate) enum TreeSyncNode {
/// A (possibly blank) leaf node.
Leaf(Box<TreeSyncLeafNode>),
/// A (possibly blank) parent node.
Parent(Box<TreeSyncParentNode>),
}
impl From<Node> for TreeSyncNode {
fn from(node: Node) -> Self {
match node {
Node::LeafNode(leaf) => TreeSyncNode::Leaf(Box::new((*leaf).into())),
Node::ParentNode(parent) => TreeSyncNode::Parent(Box::new((*parent).into())),
}
}
}
/// Unwraps a [`TreeSyncNode`] into an optional [`Node`].
///
/// A blank leaf or blank parent yields `None`.
impl From<TreeSyncNode> for Option<Node> {
    fn from(tsn: TreeSyncNode) -> Self {
        match tsn {
            TreeSyncNode::Parent(boxed_parent) => Self::from(*boxed_parent),
            TreeSyncNode::Leaf(boxed_leaf) => Self::from(*boxed_leaf),
        }
    }
}
/// Re-labels a generic [`TreeNode`] as a [`TreeSyncNode`].
///
/// The existing box is moved across as-is, so nothing is re-allocated.
impl From<TreeNode<TreeSyncLeafNode, TreeSyncParentNode>> for TreeSyncNode {
    fn from(tree_node: TreeNode<TreeSyncLeafNode, TreeSyncParentNode>) -> Self {
        match tree_node {
            TreeNode::Parent(boxed_parent) => Self::Parent(boxed_parent),
            TreeNode::Leaf(boxed_leaf) => Self::Leaf(boxed_leaf),
        }
    }
}
/// Re-labels a [`TreeSyncNode`] as a generic [`TreeNode`].
///
/// The existing box is moved across as-is, so nothing is re-allocated.
impl From<TreeSyncNode> for TreeNode<TreeSyncLeafNode, TreeSyncParentNode> {
    fn from(tsn: TreeSyncNode) -> Self {
        match tsn {
            TreeSyncNode::Parent(boxed_parent) => Self::Parent(boxed_parent),
            TreeSyncNode::Leaf(boxed_leaf) => Self::Leaf(boxed_leaf),
        }
    }
}
/// Treesync wrapper around an optional [`LeafNode`].
///
/// The derived `Default` produces a blank leaf (`node == None`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
pub(crate) struct TreeSyncLeafNode {
/// The wrapped leaf, or `None` if this leaf is blank.
node: Option<LeafNode>,
}
impl TreeSyncLeafNode {
    /// Returns a blank leaf, i.e. one that wraps no [`LeafNode`].
    pub(in crate::treesync) fn blank() -> Self {
        Self { node: None }
    }

    /// Borrows the wrapped leaf node; `None` means the leaf is blank.
    pub(in crate::treesync) fn node(&self) -> &Option<LeafNode> {
        &self.node
    }

    /// Computes the tree hash of this leaf at position `leaf_index`.
    ///
    /// A blank leaf is hashed with `None` as its content.
    ///
    /// # Errors
    /// Returns a [`LibraryError`] if computing the hash fails.
    pub(in crate::treesync) fn compute_tree_hash(
        &self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
        leaf_index: LeafNodeIndex,
    ) -> Result<Vec<u8>, LibraryError> {
        let leaf_content = self.node.as_ref();
        let digest = TreeHashInput::new_leaf(&leaf_index, leaf_content).hash(crypto, ciphersuite)?;
        Ok(digest)
    }
}
/// Wraps a [`LeafNode`] as a populated (non-blank) [`TreeSyncLeafNode`].
impl From<LeafNode> for TreeSyncLeafNode {
    fn from(node: LeafNode) -> Self {
        let node = Some(node);
        Self { node }
    }
}
impl From<LeafNode> for Box<TreeSyncLeafNode> {
fn from(node: LeafNode) -> Self {
Box::new(TreeSyncLeafNode { node: Some(node) })
}
}
impl From<TreeSyncLeafNode> for Option<Node> {
fn from(tsln: TreeSyncLeafNode) -> Self {
tsln.node.map(|n| Node::LeafNode(Box::new(n)))
}
}
/// Treesync wrapper around an optional [`ParentNode`].
///
/// The derived `Default` produces a blank parent (`node == None`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
pub(crate) struct TreeSyncParentNode {
/// The wrapped parent, or `None` if this parent is blank.
node: Option<ParentNode>,
}
impl TreeSyncParentNode {
    /// Returns a blank parent, i.e. one that wraps no [`ParentNode`].
    pub(in crate::treesync) fn blank() -> Self {
        Self { node: None }
    }

    /// Borrows the wrapped parent node; `None` means the parent is blank.
    pub(in crate::treesync) fn node(&self) -> &Option<ParentNode> {
        &self.node
    }

    /// Mutably borrows the parent-node slot, so callers can replace or blank it.
    pub(in crate::treesync) fn node_mut(&mut self) -> &mut Option<ParentNode> {
        &mut self.node
    }

    /// Computes the tree hash of this parent.
    ///
    /// The hash covers the parent's own content plus the already-computed
    /// hashes of its left and right subtrees.
    ///
    /// Leaf indices contained in `exclusion_list` are dropped from the
    /// parent's unmerged leaves before hashing. The stored node is never
    /// modified: filtering happens on a clone. No clone is made when
    /// `exclusion_list` is empty or the parent is blank.
    ///
    /// # Errors
    /// Returns a [`LibraryError`] if computing the hash fails.
    pub(in crate::treesync) fn compute_tree_hash(
        &self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
        left_hash: Vec<u8>,
        right_hash: Vec<u8>,
        exclusion_list: &HashSet<&LeafNodeIndex>,
    ) -> Result<Vec<u8>, LibraryError> {
        // Build a pruned copy only when there is both something to exclude
        // and a node to prune; otherwise the stored node is hashed directly.
        let pruned: Option<ParentNode> = match self.node.as_ref() {
            Some(parent_node) if !exclusion_list.is_empty() => {
                let mut copy = parent_node.clone();
                let kept_leaves = copy
                    .unmerged_leaves()
                    .iter()
                    .filter(|leaf_index| !exclusion_list.contains(leaf_index))
                    .cloned()
                    .collect();
                copy.set_unmerged_leaves(kept_leaves);
                Some(copy)
            }
            _ => None,
        };
        // Prefer the pruned copy; fall back to the (possibly blank) stored node.
        let hashed_node = pruned.as_ref().or(self.node.as_ref());
        let digest = TreeHashInput::new_parent(
            hashed_node,
            VLByteSlice(&left_hash),
            VLByteSlice(&right_hash),
        )
        .hash(crypto, ciphersuite)?;
        Ok(digest)
    }
}
/// Wraps a [`ParentNode`] as a populated (non-blank) [`TreeSyncParentNode`].
impl From<ParentNode> for TreeSyncParentNode {
    fn from(node: ParentNode) -> Self {
        let node = Some(node);
        Self { node }
    }
}
impl From<ParentNode> for Box<TreeSyncParentNode> {
fn from(node: ParentNode) -> Self {
Box::new(TreeSyncParentNode { node: Some(node) })
}
}
impl From<TreeSyncParentNode> for Option<Node> {
fn from(tspn: TreeSyncParentNode) -> Self {
tspn.node.map(|n| Node::ParentNode(Box::new(n)))
}
}