#[cfg(feature = "std")]
use std::collections::HashSet;
#[cfg(mls_build_async)]
use futures::{stream::FuturesUnordered, TryStreamExt};
use alloc::{vec, vec::Vec};
use mls_rs_core::{
crypto::{CipherSuite, SignatureSecretKey},
error::IntoAnyError,
extension::ExtensionList,
identity::{IdentityProvider, SigningIdentity},
protocol_version::ProtocolVersion,
};
use crate::{client::MlsError, tree_kem::TreeKemPublic, Client, Group, MlsMessage};
use crate::{group::Roster, time::MlsTime};
use super::{
proposal::ReInitProposal, ClientConfig, ExportedTree, JustPreSharedKeyID, MessageProcessor,
NewMemberInfo, PreSharedKeyID, PskGroupId, PskSecretInput, ResumptionPSKUsage, ResumptionPsk,
};
pub struct ReinitClient<C: ClientConfig + Clone> {
client: Client<C>,
reinit: ReInitProposal,
psk_input: PskSecretInput,
old_public_tree: TreeKemPublic,
}
impl<C> Group<C>
where
C: ClientConfig + Clone,
{
fn branch_group_creator(
&self,
timestamp: Option<MlsTime>,
group_id: Vec<u8>,
) -> Result<GroupCreator<C>, MlsError> {
Ok(GroupCreator {
group_id,
cipher_suite: self.cipher_suite(),
version: self.protocol_version(),
extensions: self.group_state().context.extensions.clone(),
psk_input: self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
timestamp,
signer: self.signer.clone(),
config: self.config.clone(),
typ: GroupCreationType::Branch,
})
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn branch(
&self,
sub_group_id: Vec<u8>,
new_key_packages: Vec<MlsMessage>,
timestamp: Option<MlsTime>,
) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
self.branch_group_creator(timestamp, sub_group_id)?
.create(
new_key_packages,
self.current_member_signing_identity()?.clone(),
self.current_user_leaf_node()?.ungreased_extensions(),
self.roster(),
)
.await
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn join_subgroup(
&self,
welcome: &MlsMessage,
tree_data: Option<ExportedTree<'_>>,
timestamp: Option<MlsTime>,
) -> Result<(Group<C>, NewMemberInfo), MlsError> {
self.branch_group_creator(timestamp, vec![])?
.join(welcome, tree_data, false, self.roster())
.await
}
pub fn get_reinit_client(
self,
new_signer: Option<SignatureSecretKey>,
new_signing_identity: Option<SigningIdentity>,
) -> Result<ReinitClient<C>, MlsError> {
let psk_input = self.resumption_psk_input(ResumptionPSKUsage::Reinit)?;
let new_signing_identity = new_signing_identity
.map(Ok)
.unwrap_or_else(|| self.current_member_signing_identity().cloned())?;
let reinit = self
.state
.pending_reinit
.ok_or(MlsError::PendingReInitNotFound)?;
let new_signer = match new_signer {
Some(signer) => signer,
None => self.signer,
};
let client = Client::new(
self.config,
Some(new_signer),
Some((new_signing_identity, reinit.new_cipher_suite())),
reinit.new_version(),
);
Ok(ReinitClient {
client,
reinit,
psk_input,
old_public_tree: self.state.public_tree,
})
}
fn resumption_psk_input(&self, usage: ResumptionPSKUsage) -> Result<PskSecretInput, MlsError> {
let psk = self.epoch_secrets.resumption_secret.clone();
let id = JustPreSharedKeyID::Resumption(ResumptionPsk {
usage,
psk_group_id: PskGroupId(self.group_id().to_vec()),
psk_epoch: self.current_epoch(),
});
let id = PreSharedKeyID::new(id, self.cipher_suite_provider())?;
Ok(PskSecretInput { id, psk })
}
}
impl<C: ClientConfig + Clone> ReinitClient<C> {
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn generate_key_package(
&self,
timestamp: Option<MlsTime>,
) -> Result<MlsMessage, MlsError> {
self.client
.generate_key_package_message(Default::default(), Default::default(), timestamp)
.await
}
fn group_creator(self, timestamp: Option<MlsTime>) -> GroupCreator<C> {
GroupCreator {
group_id: self.reinit.group_id,
cipher_suite: self.reinit.cipher_suite,
version: self.reinit.version,
extensions: self.reinit.extensions,
psk_input: self.psk_input,
timestamp,
signer: self.client.signer.unwrap(),
config: self.client.config,
typ: GroupCreationType::Reinit,
}
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn commit(
mut self,
new_key_packages: Vec<MlsMessage>,
new_leaf_node_extensions: ExtensionList,
timestamp: Option<MlsTime>,
) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
let signing_identity = self.client.signing_identity.take();
let old_public_tree = core::mem::take(&mut self.old_public_tree);
self.group_creator(timestamp)
.create(
new_key_packages,
signing_identity.unwrap().0,
new_leaf_node_extensions,
old_public_tree.roster(),
)
.await
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn join(
mut self,
welcome: &MlsMessage,
tree_data: Option<ExportedTree<'_>>,
timestamp: Option<MlsTime>,
) -> Result<(Group<C>, NewMemberInfo), MlsError> {
let old_public_tree = core::mem::take(&mut self.old_public_tree);
self.group_creator(timestamp)
.join(welcome, tree_data, true, old_public_tree.roster())
.await
}
}
struct GroupCreator<C> {
group_id: Vec<u8>,
cipher_suite: CipherSuite,
version: ProtocolVersion,
extensions: ExtensionList,
psk_input: PskSecretInput,
timestamp: Option<MlsTime>,
signer: SignatureSecretKey,
config: C,
typ: GroupCreationType,
}
impl<C: ClientConfig> GroupCreator<C> {
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn create(
self,
new_key_packages: Vec<MlsMessage>,
signing_identity: SigningIdentity,
leaf_node_extensions: ExtensionList,
old_roster: Roster<'_>,
) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
let mut group = Group::new(
self.config,
Some(self.group_id),
self.cipher_suite,
self.version,
signing_identity,
self.extensions.clone(),
leaf_node_extensions,
self.signer,
self.timestamp,
)
.await?;
group.previous_psk = Some(self.psk_input);
let mut commit = group.commit_builder();
for kp in new_key_packages.into_iter() {
commit = commit.add_member(kp)?;
}
let commit = commit.build().await?;
group.apply_pending_commit().await?;
group.previous_psk = None;
check_that_subgroup_is_a_subset(old_roster, &group, self.typ).await?;
Ok((group, commit.welcome_messages))
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn join(
self,
welcome: &MlsMessage,
tree_data: Option<ExportedTree<'_>>,
verify_group_id: bool,
old_roster: Roster<'_>,
) -> Result<(Group<C>, NewMemberInfo), MlsError> {
let (group, new_member_info) = Group::from_welcome_message(
welcome,
tree_data,
self.config,
self.signer,
Some(self.psk_input),
self.timestamp,
)
.await?;
check_that_subgroup_is_a_subset(old_roster, &group, self.typ).await?;
if group.protocol_version() != self.version {
Err(MlsError::ProtocolVersionMismatch)
} else if group.cipher_suite() != self.cipher_suite {
Err(MlsError::CipherSuiteMismatch)
}
else if group.current_epoch() != 1 {
Err(MlsError::InitialEpochNotOne)
} else if verify_group_id && group.group_id() != self.group_id {
Err(MlsError::GroupIdMismatch)
} else if group.group_state().context.extensions != self.extensions {
Err(MlsError::ReInitExtensionsMismatch)
} else {
Ok((group, new_member_info))
}
}
}
enum GroupCreationType {
Branch,
Reinit,
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn check_that_subgroup_is_a_subset<C: ClientConfig>(
old_roster: Roster<'_>,
new_group: &Group<C>,
typ: GroupCreationType,
) -> Result<(), MlsError> {
if matches!(typ, GroupCreationType::Reinit)
&& old_roster.public_tree.len() != new_group.roster().public_tree.len()
{
return Err(MlsError::NotASubgroup);
}
let provider = new_group.identity_provider();
let extensions = new_group.context().extensions();
let old_identities = collect_identities(extensions, old_roster, &provider).await?;
let new_identities = collect_identities(extensions, new_group.roster(), &provider).await?;
new_identities
.is_subset(&old_identities)
.then_some(())
.ok_or(MlsError::NotASubgroup)?;
Ok(())
}
#[cfg(feature = "std")]
#[cfg(not(mls_build_async))]
fn collect_identities<I: IdentityProvider>(
extensions: &ExtensionList,
roster: Roster<'_>,
provider: &I,
) -> Result<HashSet<Vec<u8>>, MlsError> {
roster
.members_iter()
.map(|m| {
provider
.identity(&m.signing_identity, extensions)
.map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))
})
.collect()
}
#[cfg(feature = "std")]
#[cfg(mls_build_async)]
async fn collect_identities<I: IdentityProvider>(
extensions: &ExtensionList,
roster: Roster<'_>,
provider: &I,
) -> Result<HashSet<Vec<u8>>, MlsError> {
roster
.members_iter()
.map(async move |m| {
provider
.identity(&m.signing_identity, extensions)
.await
.map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))
})
.collect::<FuturesUnordered<_>>()
.try_collect()
.await
}
#[cfg(not(feature = "std"))]
struct Identities(Vec<Vec<u8>>);
#[cfg(not(feature = "std"))]
impl Identities {
fn is_subset(&self, other: &Self) -> bool {
self.0.iter().all(|i| other.0.contains(i))
}
}
#[cfg(not(feature = "std"))]
fn collect_identities<I: IdentityProvider>(
extensions: &ExtensionList,
roster: Roster<'_>,
provider: &I,
) -> Result<Identities, MlsError> {
roster
.members_iter()
.map(|m| {
provider
.identity(&m.signing_identity, extensions)
.map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))
})
.collect::<Result<Vec<_>, _>>()
.map(Identities)
}