use std::collections::HashSet;
use openmls_traits::crypto::OpenMlsCrypto;
use openmls_traits::types::Ciphersuite;
use serde::{Deserialize, Serialize};
use tls_codec::Serialize as TlsSerialize;
#[cfg(feature = "extensions-draft-08")]
use super::errors::ApplyAppDataUpdateError;
use super::PublicGroup;
use crate::{
binary_tree::{array_representation::TreeSize, LeafNodeIndex},
error::LibraryError,
extensions::Extensions,
framing::{mls_auth_content::AuthenticatedContent, public_message::InterimTranscriptHashInput},
group::GroupContext,
messages::{proposals::AddProposal, ConfirmationTag, EncryptedGroupSecrets},
schedule::{psk::PreSharedKeyId, CommitSecret, JoinerSecret},
treesync::{
diff::{StagedTreeSyncDiff, TreeSyncDiff},
errors::{ApplyUpdatePathError, TreeSyncFromNodesError},
node::{
encryption_keys::EncryptionKeyPair, leaf_node::LeafNode,
parent_node::PlainUpdatePathNode,
},
treekem::{DecryptPathParams, UpdatePath, UpdatePathNode},
RatchetTree, TreeSync,
},
};
pub(crate) mod apply_proposals;
pub(crate) mod compute_path;
/// Working copy of a [`PublicGroup`]'s public state that can be modified
/// without touching the group itself.
///
/// The tree part is a [`TreeSyncDiff`] borrowed from the group's `TreeSync`
/// for the lifetime `'a`. The group context, interim transcript hash and
/// confirmation tag are owned copies taken when the diff is created.
pub(crate) struct PublicGroupDiff<'a> {
/// Pending tree changes on top of the group's `TreeSync`.
diff: TreeSyncDiff<'a>,
/// Owned copy of the group context; tree hash, epoch, extensions and the
/// confirmed transcript hash are updated on this copy.
group_context: GroupContext,
/// Interim transcript hash, replaced together with `confirmation_tag`.
interim_transcript_hash: Vec<u8>,
/// Confirmation tag that the current `interim_transcript_hash` was derived from.
confirmation_tag: ConfirmationTag,
}
impl<'a> PublicGroupDiff<'a> {
    /// Opens a new diff against `public_group`.
    ///
    /// The tree side starts out as an empty diff over the group's `TreeSync`.
    /// The group context, interim transcript hash and confirmation tag are
    /// copied, so later updates only affect this diff.
    pub(super) fn new(public_group: &'a PublicGroup) -> PublicGroupDiff<'a> {
        let diff = public_group.treesync().empty_diff();
        let group_context = public_group.group_context().clone();
        let interim_transcript_hash = public_group.interim_transcript_hash().to_vec();
        let confirmation_tag = public_group.confirmation_tag().clone();

        Self {
            diff,
            group_context,
            interim_transcript_hash,
            confirmation_tag,
        }
    }

    /// Consumes the diff and freezes it into a [`StagedPublicGroupDiff`].
    ///
    /// The tree diff is staged via its own `into_staged_diff`. The group
    /// context, interim transcript hash and confirmation tag are carried over
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if staging the tree diff fails.
    pub(crate) fn into_staged_diff(
        self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
    ) -> Result<StagedPublicGroupDiff, LibraryError> {
        let Self {
            diff,
            group_context,
            interim_transcript_hash,
            confirmation_tag,
        } = self;

        let staged_diff = diff.into_staged_diff(crypto, ciphersuite)?;

        Ok(StagedPublicGroupDiff {
            staged_diff,
            group_context,
            interim_transcript_hash,
            confirmation_tag,
        })
    }

    /// Encrypts the group secrets for `invited_members`.
    ///
    /// This is a pure delegation to the underlying tree diff. All arguments
    /// are forwarded unchanged and in the same order.
    // The argument list mirrors the tree-diff method one-to-one, which is why
    // clippy's argument-count lint is silenced instead of bundling parameters.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn encrypt_group_secrets(
        &self,
        joiner_secret: &JoinerSecret,
        invited_members: Vec<(LeafNodeIndex, AddProposal)>,
        plain_path_option: Option<&[PlainUpdatePathNode]>,
        presharedkeys: &[PreSharedKeyId],
        encrypted_group_info: &[u8],
        crypto: &impl OpenMlsCrypto,
        leaf_index: LeafNodeIndex,
    ) -> Result<Vec<EncryptedGroupSecrets>, LibraryError> {
        let tree_diff = &self.diff;
        tree_diff.encrypt_group_secrets(
            joiner_secret,
            invited_members,
            plain_path_option,
            presharedkeys,
            encrypted_group_info,
            crypto,
            leaf_index,
        )
    }

    /// Size of the tree as seen through this diff.
    pub(crate) fn tree_size(&self) -> TreeSize {
        self.diff.tree_size()
    }

    /// Exports the ratchet tree as seen through this diff.
    pub(crate) fn export_ratchet_tree(&self) -> RatchetTree {
        self.diff.export_ratchet_tree()
    }

    /// Decrypts `update_path`, which was sent by `sender_leaf_index`, on
    /// behalf of `own_leaf_index`.
    ///
    /// Returns the decrypted encryption key pairs together with the commit
    /// secret produced by the tree diff.
    ///
    /// # Errors
    ///
    /// A failure to TLS-serialize the group context is reported as a
    /// `LibraryError` (converted into [`ApplyUpdatePathError`]). Any other
    /// error comes from the tree diff's own path decryption.
    pub(crate) fn decrypt_path(
        &self,
        crypto: &impl OpenMlsCrypto,
        owned_keys: &[&EncryptionKeyPair],
        own_leaf_index: LeafNodeIndex,
        sender_leaf_index: LeafNodeIndex,
        update_path: &[UpdatePathNode],
        exclusion_list: &HashSet<&LeafNodeIndex>,
    ) -> Result<(Vec<EncryptionKeyPair>, CommitSecret), ApplyUpdatePathError> {
        // The tree diff receives the group context in its TLS-serialized form.
        let serialized_context = self
            .group_context
            .tls_serialize_detached()
            .map_err(LibraryError::missing_bound_check)?;
        let ciphersuite = self.group_context.ciphersuite();

        let params = DecryptPathParams {
            update_path,
            sender_leaf_index,
            exclusion_list,
            group_context: &serialized_context,
        };

        self.diff
            .decrypt_path(crypto, ciphersuite, params, owned_keys, own_leaf_index)
    }

    /// Returns the leaf at `index` as seen through this diff, or `None` when
    /// the diff holds no leaf there.
    pub(crate) fn leaf(&self, index: LeafNodeIndex) -> Option<&LeafNode> {
        self.diff.leaf(index)
    }

    /// Applies an update path received from `sender_leaf_index` to the tree
    /// diff.
    pub(crate) fn apply_received_update_path(
        &mut self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
        sender_leaf_index: LeafNodeIndex,
        update_path: &UpdatePath,
    ) -> Result<(), ApplyUpdatePathError> {
        let tree_diff = &mut self.diff;
        tree_diff.apply_received_update_path(crypto, ciphersuite, sender_leaf_index, update_path)
    }

    /// Recomputes the interim transcript hash from `confirmation_tag` and
    /// stores both the tag and the new hash.
    ///
    /// The computation reads the confirmed transcript hash currently held in
    /// this diff's group context.
    /// NOTE(review): the result therefore depends on whether
    /// `update_confirmed_transcript_hash` has already run; presumably callers
    /// update the confirmed hash first — confirm.
    ///
    /// # Errors
    ///
    /// Propagates the error from the interim transcript hash computation. On
    /// error, neither the tag nor the hash is modified.
    pub(crate) fn update_interim_transcript_hash(
        &mut self,
        ciphersuite: Ciphersuite,
        crypto: &impl OpenMlsCrypto,
        confirmation_tag: ConfirmationTag,
    ) -> Result<(), LibraryError> {
        let confirmed_hash = self.group_context.confirmed_transcript_hash();
        let new_interim_hash = InterimTranscriptHashInput::from(&confirmation_tag)
            .calculate_interim_transcript_hash(crypto, ciphersuite, confirmed_hash)?;

        self.confirmation_tag = confirmation_tag;
        self.interim_transcript_hash = new_interim_hash;
        Ok(())
    }

    /// Advances the diff's group context to reflect the current tree.
    ///
    /// The steps are:
    /// 1. recompute the tree hashes and store the new tree hash;
    /// 2. increment the epoch;
    /// 3. if `extensions` is `Some`, replace the context extensions.
    ///
    /// # Errors
    ///
    /// Propagates the error from tree-hash computation, in which case the
    /// context is left untouched.
    pub(crate) fn update_group_context(
        &mut self,
        crypto: &impl OpenMlsCrypto,
        extensions: Option<Extensions<GroupContext>>,
    ) -> Result<(), LibraryError> {
        let ciphersuite = self.group_context.ciphersuite();
        let tree_hash = self.diff.compute_tree_hashes(crypto, ciphersuite)?;

        self.group_context.update_tree_hash(tree_hash);
        self.group_context.increment_epoch();
        if let Some(new_extensions) = extensions {
            self.group_context.set_extensions(new_extensions);
        }
        Ok(())
    }

    /// Updates the confirmed transcript hash in the diff's group context.
    ///
    /// The new value is derived from the stored interim transcript hash and
    /// `commit_content`.
    pub(crate) fn update_confirmed_transcript_hash(
        &mut self,
        crypto: &impl OpenMlsCrypto,
        commit_content: &AuthenticatedContent,
    ) -> Result<(), LibraryError> {
        let interim_hash = &self.interim_transcript_hash;
        self.group_context
            .update_confirmed_transcript_hash(crypto, interim_hash, commit_content)
    }

    /// The diff's group context, including any updates applied so far.
    pub(crate) fn group_context(&self) -> &GroupContext {
        &self.group_context
    }
}
/// Frozen result of `PublicGroupDiff::into_staged_diff`.
///
/// Holds the staged tree diff together with the group context, interim
/// transcript hash and confirmation tag carried over from the diff.
///
/// NOTE: this type derives `Serialize`/`Deserialize`. Serde's derive follows
/// field declaration order, so do not reorder or rename fields without
/// considering previously serialized data.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
pub(crate) struct StagedPublicGroupDiff {
/// Staged (no longer mutable) tree changes.
pub(super) staged_diff: StagedTreeSyncDiff,
/// Group context as it stood when the diff was staged.
pub(super) group_context: GroupContext,
/// Interim transcript hash as it stood when the diff was staged.
pub(super) interim_transcript_hash: Vec<u8>,
/// Confirmation tag as it stood when the diff was staged.
pub(super) confirmation_tag: ConfirmationTag,
}
impl StagedPublicGroupDiff {
    /// Group context captured when the diff was staged.
    pub(crate) fn group_context(&self) -> &GroupContext {
        &self.group_context
    }

    /// Exports the ratchet tree that results from applying the staged diff on
    /// top of `original_tree`.
    ///
    /// `original_tree` is first rebuilt into a `TreeSync` (for `ciphersuite`).
    /// The staged diff is then exported against that rebuilt tree.
    ///
    /// # Errors
    ///
    /// Returns a [`TreeSyncFromNodesError`] if `original_tree` cannot be
    /// turned into a `TreeSync`.
    pub(crate) fn export_ratchet_tree(
        &self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
        original_tree: RatchetTree,
    ) -> Result<RatchetTree, TreeSyncFromNodesError> {
        let base_tree = TreeSync::from_ratchet_tree(crypto, ciphersuite, original_tree)?;
        let merged_tree = self.staged_diff.export_ratchet_tree(base_tree.tree());
        Ok(merged_tree)
    }
}