use std::{collections::HashMap, sync::RwLock};
use commit_builder::CommitMessageBundle;
use openmls_basic_credential::SignatureKeyPair;
use openmls_traits::{
types::{Ciphersuite, HpkeKeyPair, SignatureScheme},
OpenMlsProvider as _,
};
use tls_codec::{Deserialize, Serialize};
use super::OpenMlsRustCrypto;
use crate::{
binary_tree::array_representation::LeafNodeIndex,
ciphersuite::hash_ref::KeyPackageRef,
credentials::*,
extensions::*,
framing::*,
group::*,
key_packages::*,
messages::{group_info::GroupInfo, *},
storage::OpenMlsProvider,
treesync::{
node::{leaf_node::Capabilities, Node},
LeafNode, LeafNodeParameters, RatchetTree, RatchetTreeIn,
},
versions::ProtocolVersion,
};
use super::{errors::ClientError, ActionType};
/// An MLS participant that holds one credential per supported ciphersuite,
/// the provider it uses for crypto and storage, and the groups it belongs to.
#[derive(Debug)]
pub struct Client<Provider>
where
    Provider: OpenMlsProvider,
{
    /// Raw identity bytes of this client (compared against a message sender's id).
    pub identity: Vec<u8>,
    /// The credential (and its signature public key) registered per ciphersuite.
    pub credentials: HashMap<Ciphersuite, CredentialWithKey>,
    /// Provider supplying crypto, randomness and key storage for this client.
    pub provider: Provider,
    /// Groups this client is a member of, keyed by group id. Wrapped in a lock
    /// because the client's methods mutate groups through `&self`.
    pub groups: RwLock<HashMap<GroupId, MlsGroup>>,
}
impl<Provider: OpenMlsProvider> Client<Provider> {
pub fn get_fresh_key_package(
&self,
ciphersuite: Ciphersuite,
) -> Result<KeyPackage, ClientError<Provider::StorageError>> {
let credential_with_key = self
.credentials
.get(&ciphersuite)
.ok_or(ClientError::CiphersuiteNotSupported)?;
let keys = SignatureKeyPair::read(
self.provider.storage(),
credential_with_key.signature_key.as_slice(),
ciphersuite.signature_algorithm(),
)
.unwrap();
let key_package = KeyPackage::builder()
.build(
ciphersuite,
&self.provider,
&keys,
credential_with_key.clone(),
)
.unwrap();
Ok(key_package.key_package)
}
pub fn create_group(
&self,
mls_group_create_config: MlsGroupCreateConfig,
ciphersuite: Ciphersuite,
) -> Result<GroupId, ClientError<Provider::StorageError>> {
let credential_with_key = self
.credentials
.get(&ciphersuite)
.ok_or(ClientError::CiphersuiteNotSupported);
let credential_with_key = credential_with_key?;
let signer = SignatureKeyPair::read(
self.provider.storage(),
credential_with_key.signature_key.as_slice(),
ciphersuite.signature_algorithm(),
)
.unwrap();
let group_state = MlsGroup::new(
&self.provider,
&signer,
&mls_group_create_config,
credential_with_key.clone(),
)?;
let group_id = group_state.group_id().clone();
self.groups
.write()
.expect("An unexpected error occurred.")
.insert(group_state.group_id().clone(), group_state);
Ok(group_id)
}
pub fn join_group(
&self,
mls_group_config: MlsGroupJoinConfig,
welcome: Welcome,
ratchet_tree: Option<RatchetTreeIn>,
) -> Result<(), ClientError<Provider::StorageError>> {
let staged_join = StagedWelcome::new_from_welcome(
&self.provider,
&mls_group_config,
welcome,
ratchet_tree,
)?;
let new_group = staged_join.into_group(&self.provider)?;
self.groups
.write()
.expect("An unexpected error occurred.")
.insert(new_group.group_id().to_owned(), new_group);
Ok(())
}
pub fn receive_messages_for_group<AS: Fn(&Credential) -> bool>(
&self,
message: &ProtocolMessage,
sender_id: &[u8],
authentication_service: &AS,
) -> Result<(), ClientError<Provider::StorageError>> {
let mut group_states = self.groups.write().expect("An unexpected error occurred.");
let group_id = message.group_id();
let group_state = group_states
.get_mut(group_id)
.ok_or(ClientError::NoMatchingGroup)?;
if sender_id == self.identity && message.content_type() == ContentType::Commit {
group_state.merge_pending_commit(&self.provider)?
} else {
if message.content_type() == ContentType::Commit {
group_state.clear_pending_commit(self.provider.storage())?;
}
let processed_message = group_state
.process_message(&self.provider, message.clone())
.map_err(ClientError::ProcessMessageError)?;
match processed_message.into_content() {
ProcessedMessageContent::ApplicationMessage(_) => {}
ProcessedMessageContent::ProposalMessage(staged_proposal) => {
group_state
.store_pending_proposal(self.provider.storage(), *staged_proposal)?;
}
ProcessedMessageContent::ExternalJoinProposalMessage(staged_proposal) => {
group_state
.store_pending_proposal(self.provider.storage(), *staged_proposal)?;
}
ProcessedMessageContent::StagedCommitMessage(staged_commit) => {
for credential in staged_commit.credentials_to_verify() {
if !authentication_service(credential) {
println!(
"authentication service callback denied credential {credential:?}"
);
return Err(ClientError::NoMatchingCredential);
}
}
group_state.merge_staged_commit(&self.provider, *staged_commit)?;
}
}
}
Ok(())
}
pub fn get_members_of_group(
&self,
group_id: &GroupId,
) -> Result<Vec<Member>, ClientError<Provider::StorageError>> {
let groups = self.groups.read().expect("An unexpected error occurred.");
let group = groups.get(group_id).ok_or(ClientError::NoMatchingGroup)?;
let members = group.members().collect();
Ok(members)
}
#[allow(clippy::type_complexity)]
pub fn self_update(
&self,
action_type: ActionType,
group_id: &GroupId,
leaf_node_parameters: LeafNodeParameters,
) -> Result<
(MlsMessageOut, Option<Welcome>, Option<GroupInfo>),
ClientError<Provider::StorageError>,
> {
let mut groups = self.groups.write().expect("An unexpected error occurred.");
let group = groups
.get_mut(group_id)
.ok_or(ClientError::NoMatchingGroup)?;
let signature_pk = group.own_leaf().unwrap().signature_key();
let signer = SignatureKeyPair::read(
self.provider.storage(),
signature_pk.as_slice(),
group.ciphersuite().signature_algorithm(),
)
.unwrap();
let (msg, welcome_option, group_info) = match action_type {
ActionType::Commit => {
let bundle =
group.self_update(&self.provider, &signer, LeafNodeParameters::default())?;
let welcome = bundle.to_welcome_msg();
let (msg, _, group_info) = bundle.into_contents();
(msg, welcome, group_info)
}
ActionType::Proposal => {
let (msg, _) =
group.propose_self_update(&self.provider, &signer, leaf_node_parameters)?;
(msg, None, None)
}
};
Ok((
msg,
welcome_option.map(|w| w.into_welcome().expect("Unexpected message type.")),
group_info,
))
}
#[allow(clippy::type_complexity)]
pub fn add_members(
&self,
action_type: ActionType,
group_id: &GroupId,
key_packages: &[KeyPackage],
) -> Result<
(Vec<MlsMessageOut>, Option<Welcome>, Option<GroupInfo>),
ClientError<Provider::StorageError>,
> {
let mut groups = self.groups.write().expect("An unexpected error occurred.");
let group = groups
.get_mut(group_id)
.ok_or(ClientError::NoMatchingGroup)?;
let signature_pk = group.own_leaf().unwrap().signature_key();
let signer = SignatureKeyPair::read(
self.provider.storage(),
signature_pk.as_slice(),
group.ciphersuite().signature_algorithm(),
)
.unwrap();
let action_results = match action_type {
ActionType::Commit => {
let (messages, welcome_message, group_info) =
group.add_members(&self.provider, &signer, key_packages)?;
(
vec![messages],
Some(
welcome_message
.into_welcome()
.expect("Unexpected message type."),
),
group_info,
)
}
ActionType::Proposal => {
let mut messages = Vec::new();
for key_package in key_packages {
let message = group
.propose_add_member(&self.provider, &signer, key_package)
.map(|(out, _)| out)?;
messages.push(message);
}
(messages, None, None)
}
};
Ok(action_results)
}
#[allow(clippy::type_complexity)]
pub fn remove_members(
&self,
action_type: ActionType,
group_id: &GroupId,
targets: &[LeafNodeIndex],
) -> Result<
(Vec<MlsMessageOut>, Option<Welcome>, Option<GroupInfo>),
ClientError<Provider::StorageError>,
> {
let mut groups = self.groups.write().expect("An unexpected error occurred.");
let group = groups
.get_mut(group_id)
.ok_or(ClientError::NoMatchingGroup)?;
let signature_pk = group.own_leaf().unwrap().signature_key();
let signer = SignatureKeyPair::read(
self.provider.storage(),
signature_pk.as_slice(),
group.ciphersuite().signature_algorithm(),
)
.unwrap();
let action_results = match action_type {
ActionType::Commit => {
let (message, welcome_option, group_info) =
group.remove_members(&self.provider, &signer, targets)?;
(
vec![message],
welcome_option.map(|w| w.into_welcome().expect("Unexpected message type.")),
group_info,
)
}
ActionType::Proposal => {
let mut messages = Vec::new();
for target in targets {
let message = group
.propose_remove_member(&self.provider, &signer, *target)
.map(|(out, _)| out)?;
messages.push(message);
}
(messages, None, None)
}
};
Ok(action_results)
}
pub fn identity(&self, group_id: &GroupId) -> Option<Vec<u8>> {
let groups = self.groups.read().unwrap();
let group = groups.get(group_id).unwrap();
let leaf = group.own_leaf();
leaf.map(|l| {
let credential = BasicCredential::try_from(l.credential().clone()).unwrap();
credential.identity().to_vec()
})
}
}