use openmls_traits::{signatures::Signer, types::Ciphersuite};
use tls_codec::Serialize;
#[cfg(feature = "extensions-draft-08")]
use crate::schedule::application_export_tree::ApplicationExportTree;
use crate::{
binary_tree::{array_representation::TreeSize, LeafNodeIndex},
credentials::CredentialWithKey,
error::LibraryError,
extensions::Extensions,
group::{
past_secrets::MessageSecretsStore, public_group::errors::PublicGroupBuildError,
GroupContext, GroupId, MlsGroup, MlsGroupCreateConfig, MlsGroupCreateConfigBuilder,
MlsGroupState, NewGroupError, PublicGroup, WireFormatPolicy,
},
key_packages::Lifetime,
schedule::{
psk::{load_psks, store::ResumptionPskStore, PskSecret},
EpochSecretsResult, InitSecret, JoinerSecret, KeySchedule, PreSharedKeyId,
},
storage::OpenMlsProvider,
tree::sender_ratchet::SenderRatchetConfiguration,
treesync::{
errors::LeafNodeValidationError,
node::leaf_node::{Capabilities, LeafNode},
},
};
/// Builder for creating a new [`MlsGroup`].
///
/// Options are set through the builder methods in the `impl` block below.
/// The group is then created and persisted by calling `build`.
#[derive(Default, Debug)]
pub struct MlsGroupBuilder {
    /// ID of the group to create. When `None`, a random [`GroupId`] is
    /// generated at build time.
    group_id: Option<GroupId>,
    /// Accumulates the group-creation configuration set through the builder's
    /// setter methods.
    mls_group_create_config_builder: MlsGroupCreateConfigBuilder,
    /// When `false` (the `Default`), building fails if a group with the same
    /// ID can already be loaded from storage.
    replace_old_group: bool,
    /// Pre-shared key IDs mixed into the initial key schedule.
    /// NOTE(review): no setter for this field is visible in this file, so it
    /// appears to stay empty unless populated elsewhere in the crate — confirm.
    psk_ids: Vec<PreSharedKeyId>,
}
impl MlsGroupBuilder {
    /// Creates a builder with every option at its default value.
    pub(super) fn new() -> Self {
        Self::default()
    }

    /// Sets the ID of the group to create.
    ///
    /// If this is not called, a random [`GroupId`] is generated when the group
    /// is built.
    pub fn with_group_id(mut self, group_id: GroupId) -> Self {
        self.group_id = Some(group_id);
        self
    }

    /// Skips the check that rejects a group ID already present in storage.
    ///
    /// Without this, building fails with [`NewGroupError::GroupAlreadyExists`]
    /// when a group with the same ID can be loaded from storage.
    pub fn replace_old_group(mut self) -> Self {
        self.replace_old_group = true;
        self
    }

    /// Creates a new single-member group and writes it to `provider`'s storage.
    ///
    /// The group uses the configuration accumulated on this builder.
    /// `credential_with_key` is the creator's credential.
    ///
    /// # Errors
    ///
    /// - [`NewGroupError::GroupAlreadyExists`] if a group with the same ID is
    ///   already stored, unless [`Self::replace_old_group`] was called.
    /// - [`NewGroupError::StorageError`] if reading or writing storage fails.
    /// - Library errors, and extension-validation errors raised while building
    ///   the public group.
    pub fn build<Provider: OpenMlsProvider>(
        self,
        provider: &Provider,
        signer: &impl Signer,
        credential_with_key: CredentialWithKey,
    ) -> Result<MlsGroup, NewGroupError<Provider::StorageError>> {
        self.build_internal(provider, signer, credential_with_key, None)
    }

    /// Shared implementation of [`Self::build`].
    ///
    /// If `mls_group_create_config_option` is `Some`, that configuration is
    /// used and the configuration accumulated through this builder's setters
    /// is ignored. Otherwise the accumulated configuration is built and used.
    ///
    /// Steps:
    /// 1. Resolve the group ID.
    /// 2. Reject an existing group, unless replacement was requested.
    /// 3. Build the one-leaf public tree.
    /// 4. Run the initial key schedule.
    /// 5. Compute the initial confirmation tag.
    /// 6. Assemble the [`MlsGroup`].
    /// 7. Persist the group and the creator's leaf keypair.
    ///
    /// # Errors
    ///
    /// Same as [`Self::build`].
    pub(super) fn build_internal<Provider: OpenMlsProvider>(
        self,
        provider: &Provider,
        signer: &impl Signer,
        credential_with_key: CredentialWithKey,
        mls_group_create_config_option: Option<MlsGroupCreateConfig>,
    ) -> Result<MlsGroup, NewGroupError<Provider::StorageError>> {
        // An explicitly supplied config wins; the setter-built one is only
        // constructed when none was passed in.
        let mls_group_create_config = mls_group_create_config_option
            .unwrap_or_else(|| self.mls_group_create_config_builder.build());
        let group_id = self
            .group_id
            .unwrap_or_else(|| GroupId::random(provider.rand()));
        let ciphersuite = mls_group_create_config.ciphersuite;

        // Refuse to overwrite an existing group unless `replace_old_group`
        // was requested. A storage read error aborts the build.
        if !self.replace_old_group
            && MlsGroup::load(provider.storage(), &group_id)
                .map_err(NewGroupError::StorageError)?
                .is_some()
        {
            return Err(NewGroupError::GroupAlreadyExists);
        }

        // Build the public group from the configured extensions, lifetime and
        // capabilities. `get_secrets` returns the still-unfinished builder,
        // the commit secret, and the creator's leaf keypair (persisted at the
        // end of this function).
        let (public_group_builder, commit_secret, leaf_keypair) =
            PublicGroup::builder(group_id, ciphersuite, credential_with_key)
                .with_group_context_extensions(
                    mls_group_create_config.group_context_extensions.clone(),
                )
                .with_leaf_node_extensions(mls_group_create_config.leaf_node_extensions.clone())
                .with_lifetime(*mls_group_create_config.lifetime())
                .with_capabilities(mls_group_create_config.capabilities.clone())
                .get_secrets(provider, signer)
                // Map public-group build errors onto `NewGroupError`. Invalid
                // extensions convert via `From`.
                .map_err(|e| match e {
                    PublicGroupBuildError::LibraryError(e) => NewGroupError::LibraryError(e),
                    PublicGroupBuildError::InvalidExtensions(e) => e.into(),
                })?;

        let serialized_group_context = public_group_builder
            .group_context()
            .tls_serialize_detached()
            .map_err(LibraryError::missing_bound_check)?;

        // Derive the joiner secret from the commit secret and a freshly
        // generated random init secret, bound to the serialized group context.
        let joiner_secret = JoinerSecret::new(
            provider.crypto(),
            ciphersuite,
            commit_secret,
            &InitSecret::random(ciphersuite, provider.rand())
                .map_err(LibraryError::unexpected_crypto_error)?,
            &serialized_group_context,
        )
        .map_err(LibraryError::unexpected_crypto_error)?;

        // Store for resumption PSKs. It is seeded with this epoch's PSK once
        // the public group has been built.
        // NOTE(review): the capacity is hard-coded to 32. It does not read the
        // `number_of_resumption_psks` option this builder exposes. Confirm
        // whether that is intentional before relying on that setter.
        let mut resumption_psk_store = ResumptionPskStore::new(32);

        // Load the configured PSKs and combine them into the PSK secret.
        // Nothing has been added to the resumption store yet at this point.
        // Failures are logged at debug level and surfaced as a generic
        // LibraryError.
        let psk_secret = load_psks(provider.storage(), &resumption_psk_store, &self.psk_ids)
            .and_then(|psks| PskSecret::new(provider.crypto(), ciphersuite, psks))
            .map_err(|e| {
                log::debug!("Unexpected PSK error: {e:?}");
                LibraryError::custom("Unexpected PSK error")
            })?;

        // Initialise the key schedule from the joiner and PSK secrets, then
        // mix in the serialized group context.
        let mut key_schedule =
            KeySchedule::init(ciphersuite, provider.crypto(), &joiner_secret, psk_secret)?;
        key_schedule
            .add_context(provider.crypto(), &serialized_group_context)
            .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?;

        // Derive this epoch's secrets. With the draft-08 extensions feature,
        // the application exporter secret is also returned.
        let EpochSecretsResult {
            epoch_secrets,
            #[cfg(feature = "extensions-draft-08")]
            application_exporter,
        } = key_schedule
            .epoch_secrets(provider.crypto(), ciphersuite)
            .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?;

        // Split into the group's epoch secrets and the message secrets. The
        // message secrets are sized for a one-leaf tree with our own leaf at
        // index 0.
        let (group_epoch_secrets, message_secrets) = epoch_secrets.split_secrets(
            serialized_group_context,
            TreeSize::new(1),
            LeafNodeIndex::new(0u32),
        );

        // The initial confirmation tag is computed with the confirmation key
        // over an empty input.
        // NOTE(review): presumably this reflects the empty confirmed transcript
        // hash of a freshly created group (RFC 9420). Confirm.
        let initial_confirmation_tag = message_secrets
            .confirmation_key()
            .tag(provider.crypto(), ciphersuite, &[])
            .map_err(LibraryError::unexpected_crypto_error)?;

        // Message-secret store, configured with the `max_past_epochs` setting
        // and seeded with the current message secrets.
        let message_secrets_store = MessageSecretsStore::new_with_secret(
            mls_group_create_config.max_past_epochs(),
            message_secrets,
        );

        // Finish the public group now that the confirmation tag is known.
        let public_group = public_group_builder
            .with_confirmation_tag(initial_confirmation_tag)
            .build(provider.crypto())?;

        // Seed the resumption PSK store with this epoch's resumption PSK.
        let resumption_psk = group_epoch_secrets.resumption_psk();
        resumption_psk_store.add(public_group.group_context().epoch(), resumption_psk.clone());

        #[cfg(feature = "extensions-draft-08")]
        let application_export_tree = ApplicationExportTree::new(application_exporter);

        // Assemble the group. The creator occupies leaf 0, and the group
        // starts out in the `Operational` state.
        let mls_group = MlsGroup {
            mls_group_config: mls_group_create_config.join_config.clone(),
            own_leaf_nodes: vec![],
            aad: vec![],
            group_state: MlsGroupState::Operational,
            public_group,
            group_epoch_secrets,
            own_leaf_index: LeafNodeIndex::new(0),
            message_secrets_store,
            resumption_psk_store,
            #[cfg(feature = "extensions-draft-08")]
            application_export_tree: Some(application_export_tree),
        };

        // Persist the group state, then the leaf keypair returned by
        // `get_secrets`.
        // NOTE(review): if the second write fails, the group state written by
        // the first is not rolled back here.
        mls_group
            .store(provider.storage())
            .map_err(NewGroupError::StorageError)?;
        mls_group
            .store_epoch_keypairs(provider.storage(), &[leaf_keypair])
            .map_err(NewGroupError::StorageError)?;
        Ok(mls_group)
    }

    /// Sets the wire format policy in the group-creation config.
    pub fn with_wire_format_policy(mut self, wire_format_policy: WireFormatPolicy) -> Self {
        self.mls_group_create_config_builder = self
            .mls_group_create_config_builder
            .wire_format_policy(wire_format_policy);
        self
    }

    /// Sets the `padding_size` option in the group-creation config.
    pub fn padding_size(mut self, padding_size: usize) -> Self {
        self.mls_group_create_config_builder = self
            .mls_group_create_config_builder
            .padding_size(padding_size);
        self
    }

    /// Sets the `max_past_epochs` option in the group-creation config.
    ///
    /// The build step passes this option to the message-secret store.
    pub fn max_past_epochs(mut self, max_past_epochs: usize) -> Self {
        self.mls_group_create_config_builder = self
            .mls_group_create_config_builder
            .max_past_epochs(max_past_epochs);
        self
    }

    /// Sets the `number_of_resumption_psks` option in the group-creation
    /// config.
    ///
    /// NOTE(review): the build step in this file sizes the initial resumption
    /// PSK store with a fixed 32, not this value. Confirm where this option
    /// takes effect.
    pub fn number_of_resumption_psks(mut self, number_of_resumption_psks: usize) -> Self {
        self.mls_group_create_config_builder = self
            .mls_group_create_config_builder
            .number_of_resumption_psks(number_of_resumption_psks);
        self
    }

    /// Sets the `use_ratchet_tree_extension` option in the group-creation
    /// config.
    pub fn use_ratchet_tree_extension(mut self, use_ratchet_tree_extension: bool) -> Self {
        self.mls_group_create_config_builder = self
            .mls_group_create_config_builder
            .use_ratchet_tree_extension(use_ratchet_tree_extension);
        self
    }

    /// Sets the sender ratchet configuration in the group-creation config.
    pub fn sender_ratchet_configuration(
        mut self,
        sender_ratchet_configuration: SenderRatchetConfiguration,
    ) -> Self {
        self.mls_group_create_config_builder = self
            .mls_group_create_config_builder
            .sender_ratchet_configuration(sender_ratchet_configuration);
        self
    }

    /// Sets the lifetime in the group-creation config.
    ///
    /// The build step passes this lifetime to the public-group builder.
    pub fn lifetime(mut self, lifetime: Lifetime) -> Self {
        self.mls_group_create_config_builder =
            self.mls_group_create_config_builder.lifetime(lifetime);
        self
    }

    /// Sets the ciphersuite used for the group's key schedule and tree.
    pub fn ciphersuite(mut self, ciphersuite: Ciphersuite) -> Self {
        self.mls_group_create_config_builder = self
            .mls_group_create_config_builder
            .ciphersuite(ciphersuite);
        self
    }

    /// Sets the group-context extensions for the new group.
    pub fn with_group_context_extensions(mut self, extensions: Extensions<GroupContext>) -> Self {
        self.mls_group_create_config_builder = self
            .mls_group_create_config_builder
            .with_group_context_extensions(extensions);
        self
    }

    /// Sets the extensions for the creator's leaf node.
    ///
    /// # Errors
    ///
    /// Propagates the [`LeafNodeValidationError`] returned by the underlying
    /// config builder when it rejects `extensions`.
    pub fn with_leaf_node_extensions(
        mut self,
        extensions: Extensions<LeafNode>,
    ) -> Result<Self, LeafNodeValidationError> {
        self.mls_group_create_config_builder = self
            .mls_group_create_config_builder
            .with_leaf_node_extensions(extensions)?;
        Ok(self)
    }

    /// Sets the capabilities advertised by the creator's leaf node.
    pub fn with_capabilities(mut self, capabilities: Capabilities) -> Self {
        self.mls_group_create_config_builder = self
            .mls_group_create_config_builder
            .capabilities(capabilities);
        self
    }
}