use crate::storage::OpenMlsProvider;
use crate::test_utils::OpenMlsRustCrypto;
use crate::treesync::LeafNodeParameters;
use crate::{
binary_tree::array_representation::LeafNodeIndex,
ciphersuite::{hash_ref::KeyPackageRef, *},
credentials::*,
framing::*,
group::*,
key_packages::*,
messages::*,
treesync::{node::Node, LeafNode, RatchetTree, RatchetTreeIn},
};
use ::rand::{rngs::OsRng, RngCore, TryRngCore};
use openmls_basic_credential::SignatureKeyPair;
use openmls_traits::{
crypto::OpenMlsCrypto,
types::{Ciphersuite, HpkeKeyPair, SignatureScheme},
OpenMlsProvider as _,
};
use std::{collections::HashMap, sync::RwLock};
use tls_codec::*;
pub mod client;
pub mod errors;
use self::client::*;
use self::errors::*;
/// Snapshot of a group as tracked by the test setup.
///
/// The setup refreshes `members`, `public_tree` and `exporter_secret` from the
/// sender's view of the group every time a message has been distributed.
#[derive(Clone)]
pub struct Group {
/// Identifier of the group.
pub group_id: GroupId,
/// Pairs of leaf index and member identity bytes.
pub members: Vec<(usize, Vec<u8>)>,
/// Ciphersuite used when requesting key packages for new members.
pub ciphersuite: Ciphersuite,
/// Join configuration handed to clients that join via a welcome message.
pub group_config: MlsGroupJoinConfig,
/// Most recently exported ratchet tree of the group.
pub public_tree: RatchetTree,
/// Most recently exported secret (label `"test"`, empty context, 32 bytes).
pub exporter_secret: Vec<u8>,
}
impl Group {
    /// Picks one entry of `members` at random.
    ///
    /// Returns the member's index (narrowed to `u32`) together with a copy of
    /// its identity bytes.
    ///
    /// # Panics
    ///
    /// Panics if the group has no members, because the random value is
    /// reduced modulo the member count.
    pub fn random_group_member(&self) -> (u32, Vec<u8>) {
        let position = (OsRng.unwrap_mut().next_u32() as usize) % self.members.len();
        let (leaf_index, identity) = &self.members[position];
        (*leaf_index as u32, identity.clone())
    }

    /// Returns the identifier of this group.
    pub fn group_id(&self) -> &GroupId {
        &self.group_id
    }

    /// Iterates over all members as `(index, identity)` pairs.
    ///
    /// The iterator borrows `self` and clones each identity lazily instead of
    /// copying the whole member list before iteration starts.
    pub fn members(&self) -> impl Iterator<Item = (u32, Vec<u8>)> + '_ {
        self.members
            .iter()
            .map(|(index, id)| (*index as u32, id.clone()))
    }
}
/// Selects how a group operation is issued by a client.
///
/// The value is forwarded unchanged to the client's `self_update`,
/// `add_members` and `remove_members` calls.
#[derive(Debug)]
pub enum ActionType {
/// Presumably: perform the operation as a commit (confirm against `Client`).
Commit,
/// Presumably: perform the operation as a proposal (confirm against `Client`).
Proposal,
}
/// Controls whether messages are round-tripped through the TLS codec before
/// they are delivered to clients.
#[derive(Debug, PartialEq, Eq)]
pub enum CodecUse {
/// Serialize each message and deserialize it again before delivery.
SerializedMessages,
/// Deliver the in-memory message structs as they are.
StructMessages,
}
/// Test harness holding a set of clients and the groups they take part in.
pub struct MlsGroupTestSetup<Provider: OpenMlsProvider> {
/// All clients of the setup, keyed by their identity bytes.
pub clients: RwLock<HashMap<Vec<u8>, RwLock<Client<Provider>>>>,
/// The setup's own record of every group, keyed by group id.
pub groups: RwLock<HashMap<GroupId, Group>>,
/// Maps the hash reference bytes of a handed-out key package to the identity
/// of the client that created it, until the matching welcome is delivered.
pub waiting_for_welcome: RwLock<HashMap<Vec<u8>, Vec<u8>>>,
/// Configuration used when creating groups; its `join_config` is stored in
/// the group record.
pub default_mgp: MlsGroupCreateConfig,
/// Whether messages pass through serialization before delivery.
pub use_codec: CodecUse,
}
impl<Provider: OpenMlsProvider + Default> MlsGroupTestSetup<Provider> {
/// Builds a setup containing `number_of_clients` freshly created clients.
///
/// A client's identity is the big-endian byte encoding of its position in
/// `0..number_of_clients`. For every ciphersuite the client's crypto backend
/// reports as supported, the client receives a basic credential plus a newly
/// generated signature key pair, which is also written to the provider's
/// storage. The setup starts without groups and without pending welcomes.
///
/// # Panics
///
/// Panics if generating or storing a signature key pair fails, or if the
/// public key cannot be constructed.
pub fn new(
    default_mgp: MlsGroupCreateConfig,
    number_of_clients: usize,
    use_codec: CodecUse,
) -> Self {
    let clients: HashMap<_, _> = (0..number_of_clients)
        .map(|client_number| {
            let identity = client_number.to_be_bytes().to_vec();
            let provider = Provider::default();
            let supported_ciphersuites = provider.crypto().supported_ciphersuites();
            // One credential and signature key per supported ciphersuite.
            let credentials: HashMap<_, _> = supported_ciphersuites
                .iter()
                .map(|ciphersuite| {
                    let credential = BasicCredential::new(identity.clone());
                    let key_pair =
                        SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap();
                    key_pair.store(provider.storage()).unwrap();
                    let public_key = OpenMlsSignaturePublicKey::new(
                        key_pair.public().into(),
                        key_pair.signature_scheme(),
                    )
                    .unwrap();
                    let credential_with_key = CredentialWithKey {
                        credential: credential.into(),
                        signature_key: public_key.into(),
                    };
                    (*ciphersuite, credential_with_key)
                })
                .collect();
            let client = Client {
                identity: identity.clone(),
                credentials,
                provider,
                groups: RwLock::new(HashMap::new()),
            };
            (identity, RwLock::new(client))
        })
        .collect();
    MlsGroupTestSetup {
        clients: RwLock::new(clients),
        groups: RwLock::new(HashMap::new()),
        waiting_for_welcome: RwLock::new(HashMap::new()),
        default_mgp,
        use_codec,
    }
}
/// Asks `client` for a fresh key package and remembers who it belongs to.
///
/// The key package's hash reference is recorded in `waiting_for_welcome`
/// together with the client's identity, so that a later welcome can be routed
/// to this client.
///
/// # Errors
///
/// Returns an error if the client cannot produce a key package or if the hash
/// reference cannot be computed.
pub fn get_fresh_key_package(
    &self,
    client: &Client<Provider>,
    ciphersuite: Ciphersuite,
) -> Result<KeyPackage, SetupError<Provider::StorageError>> {
    let key_package = client.get_fresh_key_package(ciphersuite)?;
    let key_package_ref = key_package
        .hash_ref(client.provider.crypto())?
        .as_slice()
        .to_vec();
    {
        let mut pending = self
            .waiting_for_welcome
            .write()
            .expect("An unexpected error occurred.");
        pending.insert(key_package_ref, client.identity.clone());
    }
    Ok(key_package)
}
/// Looks up the member recorded at leaf `index` in `group` and returns
/// whatever that client's `identity` call yields for this group.
///
/// # Panics
///
/// Panics if no member is recorded at `index` or if the member is not a
/// client of this setup.
pub fn identity_by_index(&self, index: usize, group: &Group) -> Option<Vec<u8>> {
    let member_id = group
        .members
        .iter()
        .find(|(leaf_index, _)| *leaf_index == index)
        .map(|(_, member_id)| member_id)
        .expect("Couldn't find member at leaf index");
    let all_clients = self.clients.read().expect("An unexpected error occurred.");
    let client_lock = all_clients
        .get(member_id)
        .expect("An unexpected error occurred.");
    let client = client_lock.read().expect("An unexpected error occurred.");
    client.identity(&group.group_id)
}
/// Looks up the member of `group` whose recorded identity equals `id` and
/// returns whatever that client's `identity` call yields for this group.
///
/// # Panics
///
/// Panics if `id` is not recorded as a member of `group` or if the member is
/// not a client of this setup.
pub fn identity_by_id(&self, id: &[u8], group: &Group) -> Option<Vec<u8>> {
    let member_id = group
        .members
        .iter()
        .find(|(_, leaf_id)| leaf_id.as_slice() == id)
        .map(|(_, leaf_id)| leaf_id)
        // The lookup is by identity, so the message names the identity
        // rather than a leaf index.
        .expect("Couldn't find member with the given identity");
    let clients = self.clients.read().expect("An unexpected error occurred.");
    let client = clients
        .get(member_id)
        .expect("An unexpected error occurred.")
        .read()
        .expect("An unexpected error occurred.");
    client.identity(&group.group_id)
}
/// Hands `welcome` to every client it contains a secret for.
///
/// With `CodecUse::SerializedMessages` the welcome is first serialized and
/// deserialized again. For each secret in the welcome, the matching entry is
/// removed from `waiting_for_welcome` and the client found there joins the
/// group using the group's join configuration and public tree.
///
/// # Errors
///
/// Returns `SetupError::NoFreshKeyPackage` if a secret refers to a key package
/// the setup did not hand out, and propagates codec and join errors.
pub fn deliver_welcome(
    &self,
    welcome: Welcome,
    group: &Group,
) -> Result<(), SetupError<Provider::StorageError>> {
    let welcome = if self.use_codec == CodecUse::SerializedMessages {
        let welcome_bytes = welcome
            .tls_serialize_detached()
            .map_err(ClientError::TlsCodecError)?;
        Welcome::tls_deserialize(&mut welcome_bytes.as_slice())
            .map_err(ClientError::TlsCodecError)?
    } else {
        welcome
    };
    let all_clients = self.clients.read().expect("An unexpected error occurred.");
    for group_secrets in welcome.secrets() {
        // The write guard is a temporary and is released at the end of this statement.
        let joiner_id = self
            .waiting_for_welcome
            .write()
            .expect("An unexpected error occurred.")
            .remove(group_secrets.new_member().as_slice())
            .ok_or(SetupError::NoFreshKeyPackage)?;
        let joiner_lock = all_clients
            .get(&joiner_id)
            .expect("An unexpected error occurred.");
        let joiner = joiner_lock.read().expect("An unexpected error occurred.");
        joiner.join_group(
            group.group_config.clone(),
            welcome.clone(),
            Some(group.public_tree.clone().into()),
        )?;
    }
    Ok(())
}
/// Delivers `message` to the members of `group` and refreshes the group record.
///
/// With `CodecUse::SerializedMessages` the message is serialized and
/// deserialized before delivery. Every recorded member processes the message,
/// except that the sender is skipped for application messages. Delivery stops
/// at the first member that reports an error. Afterwards `group.members`,
/// `group.public_tree` and `group.exporter_secret` are replaced with the
/// sender's current view of the group.
///
/// # Errors
///
/// Propagates codec errors, the first processing error of any member, and
/// errors from reading the sender's group state.
///
/// # Panics
///
/// Panics if the message is not a protocol message, or if a member or the
/// sender is unknown to the setup.
pub fn distribute_to_members<AS: Fn(&Credential) -> bool>(
    &self,
    sender_id: &[u8],
    group: &mut Group,
    message: &MlsMessageIn,
    authentication_service: &AS,
) -> Result<(), ClientError<Provider::StorageError>> {
    let incoming: MlsMessageIn = match self.use_codec {
        CodecUse::StructMessages => message.clone(),
        CodecUse::SerializedMessages => {
            let outgoing: MlsMessageOut = message.clone().into();
            let message_bytes = outgoing
                .tls_serialize_detached()
                .map_err(ClientError::TlsCodecError)?;
            MlsMessageIn::tls_deserialize(&mut message_bytes.as_slice())
                .map_err(ClientError::TlsCodecError)?
        }
    };
    let message: ProtocolMessage = incoming
        .into_protocol_message()
        .expect("Unexptected message type.");
    let clients = self.clients.read().expect("An unexpected error occurred.");
    for (_, member_id) in group.members.iter() {
        // The sender does not process its own application messages.
        if message.content_type() == ContentType::Application && member_id == sender_id {
            continue;
        }
        let member = clients
            .get(member_id)
            .expect("An unexpected error occurred.")
            .read()
            .expect("An unexpected error occurred.");
        member.receive_messages_for_group(&message, sender_id, &authentication_service)?;
    }
    // Bring the group record in line with the sender's view.
    let sender_lock = clients
        .get(sender_id)
        .expect("An unexpected error occurred.");
    let sender = sender_lock.read().expect("An unexpected error occurred.");
    let sender_groups = sender.groups.read().expect("An unexpected error occurred.");
    let sender_group = sender_groups
        .get(&group.group_id)
        .expect("An unexpected error occurred.");
    let current_members = sender.get_members_of_group(&group.group_id)?;
    group.members = current_members
        .iter()
        .map(|member| {
            (
                member.index.usize(),
                member.credential.serialized_content().to_vec(),
            )
        })
        .collect();
    group.public_tree = sender_group.export_ratchet_tree();
    group.exporter_secret = sender_group
        .export_secret(sender.provider.crypto(), "test", &[], 32)
        .map_err(ClientError::ExportSecretError)?;
    Ok(())
}
/// Asserts that every member's group state matches the group record, then has
/// each member send a test message to the group.
///
/// For each recorded member that holds a state for this group, the exported
/// ratchet tree and exporter secret are compared with the record and a
/// "Hello World!" message is created. Once all messages exist, they are
/// distributed one after the other.
///
/// # Panics
///
/// Panics if a state differs from the record, if a member's signer cannot be
/// read from storage, or if creating or distributing a message fails.
pub fn check_group_states<AS: Fn(&Credential) -> bool>(
    &self,
    group: &mut Group,
    authentication_service: AS,
) {
    let mut messages: Vec<(Vec<u8>, MlsMessageOut)> = Vec::new();
    {
        // Scoped so the read guard on the clients is released before distribution.
        let clients = self.clients.read().expect("An unexpected error occurred.");
        for (_, member_id) in group.members.iter() {
            let member = clients
                .get(member_id)
                .expect("An unexpected error occurred.")
                .read()
                .expect("An unexpected error occurred.");
            let mut member_groups = member
                .groups
                .write()
                .expect("An unexpected error occurred.");
            let group_state = match member_groups.get_mut(&group.group_id) {
                Some(group_state) => group_state,
                None => continue,
            };
            assert_eq!(group_state.export_ratchet_tree(), group.public_tree);
            assert_eq!(
                group_state
                    .export_secret(member.provider.crypto(), "test", &[], 32)
                    .expect("An unexpected error occurred."),
                group.exporter_secret
            );
            let signature_pk = group_state.own_leaf().unwrap().signature_key();
            let signer = SignatureKeyPair::read(
                member.provider.storage(),
                signature_pk.as_slice(),
                group_state.ciphersuite().signature_algorithm(),
            )
            .unwrap();
            let message = group_state
                .create_message(&member.provider, &signer, "Hello World!".as_bytes())
                .expect("Error composing message while checking group states.");
            messages.push((member_id.to_vec(), message));
        }
    }
    for (sender_id, message) in messages {
        self.distribute_to_members(&sender_id, group, &message.into(), &authentication_service)
            .expect("Error sending messages to clients while checking group states.");
    }
}
/// Selects `number_of_members` clients that are not yet members of `group`.
///
/// Candidates are taken in the iteration order of the client map; each pick
/// is the first client that is neither a recorded member nor already picked.
///
/// # Errors
///
/// Returns `SetupError::NotEnoughClients` if the setup has fewer clients than
/// the current member count plus `number_of_members`.
///
/// # Panics
///
/// Panics if no eligible client is left despite the size check.
pub fn random_new_members_for_group(
    &self,
    group: &Group,
    number_of_members: usize,
) -> Result<Vec<Vec<u8>>, SetupError<Provider::StorageError>> {
    let clients = self.clients.read().expect("An unexpected error occurred.");
    if number_of_members + group.members.len() > clients.len() {
        return Err(SetupError::NotEnoughClients);
    }
    let mut new_member_ids: Vec<Vec<u8>> = Vec::new();
    for _ in 0..number_of_members {
        let candidate = clients
            .keys()
            .find(|client_id| {
                let already_member = group
                    .members
                    .iter()
                    .any(|(_, member_id)| member_id == *client_id);
                !already_member && !new_member_ids.contains(*client_id)
            })
            .expect("An unexpected error occurred.")
            .clone();
        new_member_ids.push(candidate);
    }
    Ok(new_member_ids)
}
pub fn create_group(
&self,
ciphersuite: Ciphersuite,
) -> Result<GroupId, SetupError<Provider::StorageError>> {
let clients = self.clients.read().expect("An unexpected error occurred.");
let group_creator_id = ((OsRng.unwrap_mut().next_u32() as usize) % clients.len())
.to_be_bytes()
.to_vec();
let group_creator = clients
.get(&group_creator_id)
.expect("An unexpected error occurred.")
.read()
.expect("An unexpected error occurred.");
let mut groups = self.groups.write().expect("An unexpected error occurred.");
let group_id = group_creator.create_group(self.default_mgp.clone(), ciphersuite)?;
let creator_groups = group_creator
.groups
.read()
.expect("An unexpected error occurred.");
let group = creator_groups
.get(&group_id)
.expect("An unexpected error occurred.");
let public_tree = group.export_ratchet_tree();
let exporter_secret =
group.export_secret(group_creator.provider.crypto(), "test", &[], 32)?;
let member_ids = vec![(0, group_creator_id)];
let group = Group {
group_id: group_id.clone(),
members: member_ids,
ciphersuite,
group_config: self.default_mgp.join_config.clone(),
public_tree,
exporter_secret,
};
groups.insert(group_id.clone(), group);
Ok(group_id)
}
/// Creates a group and grows it to `target_group_size` members.
///
/// After the group is created, the missing members are chosen up front and
/// then added in batches of one to five: each batch is committed by a
/// randomly chosen current member.
///
/// A `target_group_size` of 0 is treated like 1: the group then consists of
/// its creator only.
///
/// # Errors
///
/// Returns `SetupError::NotEnoughClients` if the setup has too few clients,
/// and propagates errors from creating the group or adding members.
pub fn create_random_group<AS: Fn(&Credential) -> bool>(
    &self,
    target_group_size: usize,
    ciphersuite: Ciphersuite,
    authentication_service: AS,
) -> Result<GroupId, SetupError<Provider::StorageError>> {
    let group_id = self.create_group(ciphersuite)?;
    let mut groups = self.groups.write().expect("An unexpected error occurred.");
    let group = groups
        .get_mut(&group_id)
        .expect("An unexpected error occurred.");
    // The creator already counts as one member; saturate so that a target of
    // zero cannot underflow.
    let members_missing = target_group_size.saturating_sub(1);
    let mut new_members = self.random_new_members_for_group(group, members_missing)?;
    while !new_members.is_empty() {
        let (_, adder_identity) = group.random_group_member();
        // Between 1 and min(5, remaining) members per batch.
        let number_of_adds =
            ((OsRng.unwrap_mut().next_u32() as usize) % 5 % new_members.len()) + 1;
        let members_to_add = new_members.drain(0..number_of_adds).collect();
        self.add_clients(
            ActionType::Commit,
            group,
            &adder_identity,
            members_to_add,
            &authentication_service,
        )?;
    }
    Ok(group_id)
}
/// Has the client `client_id` perform a self-update in `group`.
///
/// The resulting message is distributed to the group; if the client also
/// produced a welcome, it is delivered afterwards.
///
/// # Errors
///
/// Returns `SetupError::UnknownClientId` if the client does not exist, and
/// propagates errors from the update, the distribution and the welcome
/// delivery.
pub fn self_update<AS: Fn(&Credential) -> bool>(
    &self,
    action_type: ActionType,
    group: &mut Group,
    client_id: &[u8],
    leaf_node_parameters: LeafNodeParameters,
    authentication_service: &AS,
) -> Result<(), SetupError<Provider::StorageError>> {
    let all_clients = self.clients.read().expect("An unexpected error occurred.");
    let updater_lock = all_clients
        .get(client_id)
        .ok_or(SetupError::UnknownClientId)?;
    let updater = updater_lock.read().expect("An unexpected error occurred.");
    let (update_message, maybe_welcome, _) =
        updater.self_update(action_type, &group.group_id, leaf_node_parameters)?;
    self.distribute_to_members(
        &updater.identity,
        group,
        &update_message.into(),
        authentication_service,
    )?;
    match maybe_welcome {
        Some(welcome) => self.deliver_welcome(welcome, group),
        None => Ok(()),
    }
}
/// Has the client `adder_id` add the clients in `addees` to `group`.
///
/// A fresh key package is requested from every addee, the adder creates the
/// add messages, each message is distributed to the group, and a resulting
/// welcome (if any) is delivered.
///
/// # Errors
///
/// Returns `SetupError::UnknownClientId` if the adder or an addee does not
/// exist and `SetupError::ClientAlreadyInGroup` if an addee is already a
/// recorded member. Other errors are propagated.
pub fn add_clients<AS: Fn(&Credential) -> bool>(
    &self,
    action_type: ActionType,
    group: &mut Group,
    adder_id: &[u8],
    addees: Vec<Vec<u8>>,
    authentication_service: &AS,
) -> Result<(), SetupError<Provider::StorageError>> {
    let all_clients = self.clients.read().expect("An unexpected error occurred.");
    let adder_lock = all_clients
        .get(adder_id)
        .ok_or(SetupError::UnknownClientId)?;
    let adder = adder_lock.read().expect("An unexpected error occurred.");
    let addee_already_member = addees.iter().any(|addee_id| {
        group
            .members
            .iter()
            .any(|(_, member_id)| member_id == addee_id)
    });
    if addee_already_member {
        return Err(SetupError::ClientAlreadyInGroup);
    }
    let mut key_packages = Vec::with_capacity(addees.len());
    for addee_id in addees.iter() {
        let addee_lock = all_clients
            .get(addee_id)
            .ok_or(SetupError::UnknownClientId)?;
        let addee = addee_lock.read().expect("An unexpected error occurred.");
        key_packages.push(self.get_fresh_key_package(&addee, group.ciphersuite)?);
    }
    let (add_messages, maybe_welcome, _) =
        adder.add_members(action_type, &group.group_id, &key_packages)?;
    for add_message in add_messages {
        self.distribute_to_members(
            adder_id,
            group,
            &add_message.into(),
            authentication_service,
        )?;
    }
    match maybe_welcome {
        Some(welcome) => self.deliver_welcome(welcome, group),
        None => Ok(()),
    }
}
/// Has the client `remover_id` remove the members at `target_members` from
/// `group`.
///
/// Every message produced by the remover is distributed to the group, and a
/// resulting welcome (if any) is delivered.
///
/// # Errors
///
/// Returns `SetupError::UnknownClientId` if the remover does not exist, and
/// propagates errors from the removal, the distribution and the welcome
/// delivery.
pub fn remove_clients<AS: Fn(&Credential) -> bool>(
    &self,
    action_type: ActionType,
    group: &mut Group,
    remover_id: &[u8],
    target_members: &[LeafNodeIndex],
    authentication_service: AS,
) -> Result<(), SetupError<Provider::StorageError>> {
    let all_clients = self.clients.read().expect("An unexpected error occurred.");
    let remover_lock = all_clients
        .get(remover_id)
        .ok_or(SetupError::UnknownClientId)?;
    let remover = remover_lock.read().expect("An unexpected error occurred.");
    let (removal_messages, maybe_welcome, _) =
        remover.remove_members(action_type, &group.group_id, target_members)?;
    for removal_message in removal_messages {
        self.distribute_to_members(
            remover_id,
            group,
            &removal_message.into(),
            &authentication_service,
        )?;
    }
    match maybe_welcome {
        Some(welcome) => self.deliver_welcome(welcome, group),
        None => Ok(()),
    }
}
pub fn perform_random_operation<AS: Fn(&Credential) -> bool>(
&self,
group: &mut Group,
authentication_service: &AS,
) -> Result<(), SetupError<Provider::StorageError>> {
let mut rng = OsRng;
let mut rng = rng.unwrap_mut();
let member_id = group.random_group_member();
println!("Member performing the operation: {member_id:?}");
let action_type = match (rng.next_u32() as usize) % 2 {
0 => ActionType::Proposal,
1 => ActionType::Commit,
_ => return Err(SetupError::Unknown),
};
let operation_type = (rng.next_u32() as usize) % 3;
match operation_type {
0 => {
println!("Performing a self-update with action type: {action_type:?}");
self.self_update(
action_type,
group,
&member_id.1,
LeafNodeParameters::default(),
authentication_service,
)?;
}
1 => {
if group.members.len() > 1 {
let number_of_removals =
(((rng.next_u32() as usize) % group.members.len()) % 5) + 1;
let (own_index, _) = group
.members
.iter()
.find(|(_, identity)| identity == &member_id.1)
.expect("An unexpected error occurred.")
.clone();
println!("Index of the member performing the {action_type:?}: {own_index:?}");
let mut target_member_leaf_indices = Vec::new();
let mut target_member_identities = Vec::new();
let clients = self.clients.read().expect("An unexpected error occurred.");
println!("Removing members:");
for _ in 0..number_of_removals {
let mut member_list_index = (rng.next_u32() as usize) % group.members.len();
let (mut leaf_index, mut identity) =
group.members[member_list_index].clone();
while leaf_index == own_index
|| target_member_identities.contains(&identity)
{
member_list_index = (rng.next_u32() as usize) % group.members.len();
let (new_leaf_index, new_identity) =
group.members[member_list_index].clone();
leaf_index = new_leaf_index;
identity = new_identity;
}
let client = clients
.get(&identity)
.expect("An unexpected error occurred.")
.read()
.expect("An unexpected error occurred.");
let client_group =
client.groups.read().expect("An unexpected error occurred.");
let client_group = client_group
.get(&group.group_id)
.expect("An unexpected error occurred.");
target_member_leaf_indices.push(client_group.own_leaf_index());
target_member_identities.push(identity);
}
self.remove_clients(
action_type,
group,
&member_id.1,
&target_member_leaf_indices,
authentication_service,
)?
};
}
2 => {
let clients_left = self
.clients
.read()
.expect("An unexpected error occurred.")
.len()
- group.members.len();
if clients_left > 0 {
let number_of_adds = (((rng.next_u32() as usize) % clients_left) % 5) + 1;
let new_member_ids = self
.random_new_members_for_group(group, number_of_adds)
.expect("An unexpected error occurred.");
println!("{action_type:?}: Adding new clients: {new_member_ids:?}");
self.add_clients(
action_type,
group,
&member_id.1,
new_member_ids,
authentication_service,
)?;
}
}
_ => return Err(SetupError::Unknown),
};
Ok(())
}
}
/// Authentication service callback that accepts every credential.
///
/// The credential is ignored and `true` is always returned.
pub fn noop_authentication_service(_cred: &Credential) -> bool {
true
}