use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
use tls_codec::{Serialize, TlsSerialize, TlsSize, VLByteSlice};
use crate::{
binary_tree::array_representation::LeafNodeIndex, ciphersuite::HpkePublicKey,
error::LibraryError,
};
use super::{node::parent_node::ParentNode, LeafNode};
/// Serializable input to the parent-hash computation.
///
/// The fields are TLS-serialized in declaration order (via the `TlsSerialize` derive), so the
/// order below is part of the hashed byte layout and must not be changed.
#[derive(TlsSerialize, TlsSize)]
pub(super) struct ParentHashInput<'a> {
    /// HPKE public key of the node whose parent hash is being computed.
    public_key: &'a HpkePublicKey,
    /// Parent-hash bytes, serialized as a variable-length byte vector.
    parent_hash: VLByteSlice<'a>,
    /// Tree hash of the original sibling subtree, serialized as a variable-length byte vector.
    original_sibling_tree_hash: VLByteSlice<'a>,
}
impl<'a> ParentHashInput<'a> {
    /// Assembles a [`ParentHashInput`] from borrowed components.
    ///
    /// The raw byte slices are wrapped in [`VLByteSlice`] so that they are serialized with a
    /// variable-length prefix. Nothing is copied; the result borrows from the arguments.
    pub(super) fn new(
        public_key: &'a HpkePublicKey,
        parent_hash: &'a [u8],
        original_sibling_tree_hash: &'a [u8],
    ) -> Self {
        let parent_hash = VLByteSlice(parent_hash);
        let original_sibling_tree_hash = VLByteSlice(original_sibling_tree_hash);
        Self {
            public_key,
            parent_hash,
            original_sibling_tree_hash,
        }
    }

    /// TLS-serializes this input and hashes it with the ciphersuite's hash algorithm.
    ///
    /// # Errors
    ///
    /// - A serialization failure is reported through [`LibraryError::missing_bound_check`].
    /// - A failure from the crypto provider's `hash` is reported through
    ///   [`LibraryError::unexpected_crypto_error`].
    pub(super) fn hash(
        &self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
    ) -> Result<Vec<u8>, LibraryError> {
        self.tls_serialize_detached()
            .map_err(LibraryError::missing_bound_check)
            .and_then(|encoded| {
                crypto
                    .hash(ciphersuite.hash_algorithm(), &encoded)
                    .map_err(LibraryError::unexpected_crypto_error)
            })
    }
}
/// Tagged node payload used inside [`TreeHashInput`].
///
/// `#[repr(u8)]` together with the explicit `tls_codec` discriminants makes the variant tag a
/// single byte on the wire: `1` for a leaf and `2` for a parent. These values are part of the
/// hashed encoding and must not be renumbered.
#[derive(TlsSerialize, TlsSize)]
#[repr(u8)]
enum NodeType<'a> {
    /// Leaf-node hash input; serialized with discriminant `1`.
    #[tls_codec(discriminant = 1)]
    Leaf(LeafNodeHashInput<'a>),
    /// Parent-node hash input; serialized with discriminant `2`.
    #[tls_codec(discriminant = 2)]
    Parent(ParentNodeHashInput<'a>),
}
/// Serializable input to the tree-hash computation for a single node.
///
/// It wraps a [`NodeType`], so the serialized form is the node-type tag followed by the
/// leaf- or parent-specific payload. Build values with `new_leaf` or `new_parent`.
#[derive(TlsSerialize, TlsSize)]
pub(super) struct TreeHashInput<'a> {
    /// Node-type tag plus its variant-specific hash input.
    node_type: NodeType<'a>,
}
impl<'a> TreeHashInput<'a> {
    /// Creates the tree-hash input for a leaf position.
    ///
    /// `leaf_node` is optional; `None` is presumably used when no leaf node occupies the
    /// position (a blank leaf). TODO confirm against callers.
    pub(super) fn new_leaf(leaf_index: &'a LeafNodeIndex, leaf_node: Option<&'a LeafNode>) -> Self {
        let leaf_input = LeafNodeHashInput {
            leaf_index,
            leaf_node,
        };
        Self {
            node_type: NodeType::Leaf(leaf_input),
        }
    }

    /// Creates the tree-hash input for a parent position.
    ///
    /// `parent_node` is optional, and `left_hash` / `right_hash` are the already-computed hashes
    /// of the left and right children. These are passed as pre-wrapped variable-length slices.
    pub(super) fn new_parent(
        parent_node: Option<&'a ParentNode>,
        left_hash: VLByteSlice<'a>,
        right_hash: VLByteSlice<'a>,
    ) -> Self {
        let parent_input = ParentNodeHashInput {
            parent_node,
            left_hash,
            right_hash,
        };
        Self {
            node_type: NodeType::Parent(parent_input),
        }
    }

    /// TLS-serializes this input and hashes it with the ciphersuite's hash algorithm.
    ///
    /// # Errors
    ///
    /// - A serialization failure is reported through [`LibraryError::missing_bound_check`].
    /// - A failure from the crypto provider's `hash` is reported through
    ///   [`LibraryError::unexpected_crypto_error`].
    pub(super) fn hash(
        &self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
    ) -> Result<Vec<u8>, LibraryError> {
        self.tls_serialize_detached()
            .map_err(LibraryError::missing_bound_check)
            .and_then(|encoded| {
                crypto
                    .hash(ciphersuite.hash_algorithm(), &encoded)
                    .map_err(LibraryError::unexpected_crypto_error)
            })
    }
}
/// Leaf-specific payload of [`TreeHashInput`].
///
/// Fields serialize in declaration order: the leaf index, then the optional leaf node.
#[derive(TlsSerialize, TlsSize)]
struct LeafNodeHashInput<'a> {
    /// Index of the leaf in the tree.
    leaf_index: &'a LeafNodeIndex,
    /// Leaf node at this index, if any (serialized as a TLS optional).
    leaf_node: Option<&'a LeafNode>,
}
/// Parent-specific payload of [`TreeHashInput`].
///
/// Fields serialize in declaration order: the optional parent node, then the left-child hash,
/// then the right-child hash.
#[derive(TlsSerialize, TlsSize)]
struct ParentNodeHashInput<'a> {
    /// Parent node at this position, if any (serialized as a TLS optional).
    parent_node: Option<&'a ParentNode>,
    /// Hash of the left child, serialized as a variable-length byte vector.
    left_hash: VLByteSlice<'a>,
    /// Hash of the right child, serialized as a variable-length byte vector.
    right_hash: VLByteSlice<'a>,
}