#![cfg(feature = "extensions-draft-08")]
use openmls::prelude::*;
use openmls::test_utils::single_group_test_framework::*;
use openmls_test::openmls_test;
/// Component ID under which the example counter is stored in the group's app data
/// dictionary, and which counter `AppDataUpdate` proposals target.
// NOTE(review): 0xf042 looks like an arbitrary example value — confirm it lies in a
// component-ID range intended for private/experimental use.
const COUNTER_COMPONENT_ID: u16 = 0xf042;
/// Operation carried in the payload of a counter `AppDataUpdate` proposal.
///
/// Each operation is encoded as exactly one byte: its discriminant.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CounterOperation {
    Increment = 0x01,
    Decrement = 0x02,
}

impl CounterOperation {
    /// Decodes an operation from its one-byte wire value; `None` for unknown bytes.
    fn from_byte(byte: u8) -> Option<Self> {
        match byte {
            0x01 => Some(CounterOperation::Increment),
            0x02 => Some(CounterOperation::Decrement),
            _ => None,
        }
    }

    /// Encodes the operation as its one-byte payload.
    fn to_bytes(self) -> Vec<u8> {
        vec![self as u8]
    }
}

/// Reasons a batch of counter updates can be rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CounterError {
    /// A `Decrement` would take the counter below zero.
    Underflow,
    /// An update payload is not exactly one byte naming a known [`CounterOperation`].
    InvalidOperation,
    /// An `Increment` would take the counter above `u32::MAX`.
    Overflow,
    /// The stored counter value is not a 4-byte big-endian `u32`.
    InvalidValue,
}

/// Applies a sequence of encoded counter operations to the current counter value.
///
/// `current_value` is the stored counter as 4 big-endian bytes; `None` means the
/// counter is unset and starts at 0. Updates are applied in order, and the new value
/// is returned in the same 4-byte big-endian encoding.
///
/// # Errors
///
/// - [`CounterError::InvalidValue`] if `current_value` is present but not exactly 4 bytes.
/// - [`CounterError::InvalidOperation`] if an update is not a single, known opcode byte.
/// - [`CounterError::Underflow`] / [`CounterError::Overflow`] if the counter would leave
///   the `u32` range.
fn process_counter_updates<'a>(
    current_value: Option<&[u8]>,
    updates: impl Iterator<Item = &'a [u8]>,
) -> Result<Vec<u8>, CounterError> {
    // Reject a malformed stored value instead of silently restarting from zero, which
    // would overwrite existing state rather than surface the corruption.
    let mut counter = match current_value {
        None => 0u32,
        Some(&[b0, b1, b2, b3]) => u32::from_be_bytes([b0, b1, b2, b3]),
        Some(_) => return Err(CounterError::InvalidValue),
    };
    for update in updates {
        // Payloads are produced by `CounterOperation::to_bytes`, i.e. exactly one byte;
        // empty payloads or payloads with trailing bytes are malformed.
        let operation = match update {
            [byte] => CounterOperation::from_byte(*byte),
            _ => None,
        }
        .ok_or(CounterError::InvalidOperation)?;
        // Both directions are range-checked so neither silently loses an operation.
        counter = match operation {
            CounterOperation::Increment => {
                counter.checked_add(1).ok_or(CounterError::Overflow)?
            }
            CounterOperation::Decrement => {
                counter.checked_sub(1).ok_or(CounterError::Underflow)?
            }
        };
    }
    Ok(counter.to_be_bytes().to_vec())
}
/// Creates a two-member group in which "alice" (the creator) adds "bob".
///
/// Both members' leaf nodes advertise the `AppDataDictionary` extension type and the
/// `AppDataUpdate` proposal type. The group context carries a `RequiredCapabilities`
/// extension that makes both of them mandatory. The group is created with the ratchet
/// tree extension enabled and the fixed group id `b"CounterGroup"`.
///
/// # Panics
///
/// Panics if group creation or adding Bob fails.
fn setup_group_with_app_data_support<'a, Provider: OpenMlsProvider>(
    alice_party: &'a CorePartyState<Provider>,
    bob_party: &'a CorePartyState<Provider>,
    ciphersuite: Ciphersuite,
) -> GroupState<'a, Provider> {
    // Only the extension and proposal capability lists are set; the rest stay `None`.
    let leaf_capabilities = Capabilities::new(
        None,
        None,
        Some(&[ExtensionType::AppDataDictionary]),
        Some(&[ProposalType::AppDataUpdate]),
        None,
    );
    let required_capabilities = Extension::RequiredCapabilities(
        RequiredCapabilitiesExtension::new(
            &[ExtensionType::AppDataDictionary],
            &[ProposalType::AppDataUpdate],
            &[],
        ),
    );

    // Alice's pre-group is built first, then Bob's, each with its own copy of the
    // leaf capabilities.
    let [alice_pre_group, bob_pre_group] = [alice_party, bob_party].map(|party| {
        party
            .pre_group_builder(ciphersuite)
            .with_leaf_node_capabilities(leaf_capabilities.clone())
            .build()
    });

    let group_config = MlsGroupCreateConfig::builder()
        .ciphersuite(ciphersuite)
        .capabilities(leaf_capabilities)
        .use_ratchet_tree_extension(true)
        .with_group_context_extensions(
            Extensions::single(required_capabilities).expect("valid extensions"),
        )
        .build();
    // Bob joins with the join config derived from the creator's config.
    let join_config = group_config.join_config().clone();

    let mut group = GroupState::new_from_party(
        GroupId::from_slice(b"CounterGroup"),
        alice_pre_group,
        group_config,
    )
    .expect("failed to create group");

    group
        .add_member(AddMemberConfig {
            adder: "alice",
            addees: vec![bob_pre_group],
            join_config,
            tree: None,
        })
        .expect("failed to add Bob");

    group
}
/// Applies every counter-related `AppDataUpdate` proposal to `updater`.
///
/// Proposals for other component IDs are ignored.
///
/// - If any counter proposal is a `Remove`, the counter entry is removed and processing
///   stops immediately. Counter updates collected before it are discarded.
/// - Otherwise all `Update` payloads are applied in order to the counter's previous
///   value via [`process_counter_updates`], and the result is written back.
/// - If there are no counter proposals at all, `updater` is left untouched.
///
/// # Errors
///
/// Propagates any [`CounterError`] from [`process_counter_updates`]. In that case
/// `updater` is not modified.
fn process_app_data_proposals<'a>(
    updater: &mut AppDataDictionaryUpdater<'a>,
    proposals: impl Iterator<Item = &'a AppDataUpdateProposal>,
) -> Result<(), CounterError> {
    use openmls::component::ComponentData;

    let mut payloads: Vec<&[u8]> = Vec::new();
    for counter_proposal in proposals.filter(|p| p.component_id() == COUNTER_COMPONENT_ID) {
        match counter_proposal.operation() {
            AppDataUpdateOperation::Remove => {
                // A removal short-circuits: drop the entry and ignore all updates.
                updater.remove(&COUNTER_COMPONENT_ID);
                return Ok(());
            }
            AppDataUpdateOperation::Update(payload) => payloads.push(payload.as_ref()),
        }
    }

    if !payloads.is_empty() {
        let previous = updater.old_value(COUNTER_COMPONENT_ID);
        let updated = process_counter_updates(previous, payloads.into_iter())?;
        updater.set(ComponentData::from_parts(
            COUNTER_COMPONENT_ID,
            updated.into(),
        ));
    }
    Ok(())
}
// End-to-end example of the AppDataUpdate flow for the counter component.
// Alice sends a standalone increment proposal, which Bob stores. She then commits with
// a second, inline increment. Each member computes the app data dictionary changes
// independently. The dictionaries must match and the counter must read 2.
#[openmls_test]
fn app_data_update_book_example() {
    let alice_party = CorePartyState::<Provider>::new("alice");
    let bob_party = CorePartyState::<Provider>::new("bob");
    let mut group_state = setup_group_with_app_data_support(&alice_party, &bob_party, ciphersuite);
    let [alice, bob] = group_state.members_mut(&["alice", "bob"]);

    // Step 1: Alice proposes an increment as a standalone proposal message.
    let (proposal_message, _proposal_ref) = alice
        .group
        .propose_app_data_update(
            &alice_party.provider,
            &alice.party.signer,
            COUNTER_COMPONENT_ID,
            AppDataUpdateOperation::Update(CounterOperation::Increment.to_bytes().into()),
        )
        .expect("failed to create proposal");

    // Step 2: Bob processes the proposal and stores it, so a later commit can refer
    // to it by reference.
    let processed_proposal = bob
        .group
        .process_message(
            &bob_party.provider,
            proposal_message
                .into_protocol_message()
                .expect("failed to convert Proposal MlsMessageOut to ProtocolMessage"),
        )
        .expect("failed to process proposal");
    match processed_proposal.into_content() {
        ProcessedMessageContent::ProposalMessage(proposal) => {
            bob.group
                .store_pending_proposal(bob_party.provider.storage(), *proposal)
                .expect("failed to store proposal");
        }
        _ => panic!("expected a proposal message"),
    }

    // Step 3: Alice starts a commit that adds one more inline increment.
    // NOTE(review): the final assertion expects 2, which implies the pending proposal
    // from step 1 is also included by the commit builder — confirm.
    let mut commit_stage = alice
        .group
        .commit_builder()
        .add_proposals(vec![
            Proposal::AppDataUpdate(Box::new(AppDataUpdateProposal::update(
                COUNTER_COMPONENT_ID,
                CounterOperation::Increment.to_bytes(),
            ))),
        ])
        .load_psks(alice_party.provider.storage())
        .expect("failed to load PSKs");

    // Step 4: Alice computes the new dictionary entry from the commit's AppDataUpdate
    // proposals and attaches the resulting changes before building the commit.
    let mut alice_updater = commit_stage.app_data_dictionary_updater();
    process_app_data_proposals(&mut alice_updater, commit_stage.app_data_update_proposals())
        .expect("failed to process proposals");
    commit_stage.with_app_data_dictionary_updates(alice_updater.changes());
    let commit_bundle = commit_stage
        .build(
            alice_party.provider.rand(),
            alice_party.provider.crypto(),
            &alice.party.signer,
            // Accept every proposal in the commit.
            |_proposal| true,
        )
        .expect("failed to build commit")
        .stage_commit(&alice_party.provider)
        .expect("failed to stage commit");
    let (commit_message, _welcome, _group_info) = commit_bundle.into_contents();
    let commit_in: MlsMessageIn = commit_message.into();

    // Step 5: Bob unprotects the commit without fully processing it yet, so he can
    // compute the dictionary changes himself first.
    let unverified_message = bob
        .group
        .unprotect_message(
            &bob_party.provider,
            commit_in
                .into_protocol_message()
                .expect("not a protocol message"),
        )
        .expect("failed to unprotect message");
    let mut bob_updater = bob.group.app_data_dictionary_updater();
    let committed_proposals = unverified_message
        .committed_proposals()
        .expect("not a commit");

    // Resolve each committed proposal: inline proposals are used directly, and
    // references are looked up in Bob's proposal store (filled in step 2).
    let mut app_data_updates: Vec<AppDataUpdateProposal> = Vec::new();
    for proposal_or_ref in committed_proposals.iter() {
        let validated = proposal_or_ref
            .clone()
            .validate(
                bob_party.provider.crypto(),
                ciphersuite,
                ProtocolVersion::Mls10,
            )
            .expect("invalid proposal");
        let proposal: Box<Proposal> = match validated {
            ProposalOrRef::Proposal(proposal) => proposal,
            ProposalOrRef::Reference(reference) => {
                bob.group
                    .proposal_store()
                    .proposals()
                    .find(|p| p.proposal_reference_ref() == &*reference)
                    .map(|p| Box::new(p.proposal().clone()))
                    .expect("proposal not found in store")
            }
        };
        // Only AppDataUpdate proposals feed the dictionary computation.
        if let Proposal::AppDataUpdate(app_data_proposal) = *proposal {
            app_data_updates.push(*app_data_proposal);
        }
    }
    process_app_data_proposals(&mut bob_updater, app_data_updates.iter())
        .expect("failed to process proposals");

    // Step 6: Bob finishes processing the commit with his computed changes and merges it.
    let processed_message = bob
        .group
        .process_unverified_message_with_app_data_updates(
            &bob_party.provider,
            unverified_message,
            bob_updater.changes(),
        )
        .expect("failed to process commit");
    let staged_commit = match processed_message.into_content() {
        ProcessedMessageContent::StagedCommitMessage(commit) => commit,
        _ => panic!("expected a staged commit"),
    };
    bob.group
        .merge_staged_commit(&bob_party.provider, *staged_commit)
        .expect("failed to merge commit");

    // Step 7: Alice merges her own pending commit.
    alice
        .group
        .merge_pending_commit(&alice_party.provider)
        .expect("failed to merge pending commit");

    // Both members must agree on the dictionary, and the counter (4 big-endian bytes)
    // must reflect both increments.
    assert_eq!(
        alice.group.extensions().app_data_dictionary(),
        bob.group.extensions().app_data_dictionary(),
        "dictionaries should match"
    );
    let alice_dict = alice
        .group
        .extensions()
        .app_data_dictionary()
        .expect("dictionary should exist");
    let counter_bytes = alice_dict
        .dictionary()
        .get(&COUNTER_COMPONENT_ID)
        .expect("counter should exist");
    let counter_value = u32::from_be_bytes(counter_bytes.try_into().expect("invalid length"));
    assert_eq!(counter_value, 2, "counter should be 2 after two increments");
}
// Decrementing a counter that has never been set must be rejected locally, before any
// commit is built: the committer's own dictionary computation reports `Underflow`.
#[openmls_test]
fn app_data_update_invalid_decrement() {
    let alice_party = CorePartyState::<Provider>::new("alice");
    let bob_party = CorePartyState::<Provider>::new("bob");
    let mut group_state = setup_group_with_app_data_support(&alice_party, &bob_party, ciphersuite);
    let [alice, _bob] = group_state.members_mut(&["alice", "bob"]);

    // One inline proposal that would take the (still unset) counter below zero.
    let decrement = AppDataUpdateProposal::update(
        COUNTER_COMPONENT_ID,
        CounterOperation::Decrement.to_bytes(),
    );
    let stage = alice
        .group
        .commit_builder()
        .add_proposals(vec![Proposal::AppDataUpdate(Box::new(decrement))])
        .load_psks(alice_party.provider.storage())
        .expect("failed to load PSKs");

    let mut updater = stage.app_data_dictionary_updater();
    let outcome = process_app_data_proposals(&mut updater, stage.app_data_update_proposals());
    assert_eq!(
        outcome,
        Err(CounterError::Underflow),
        "decrementing unset counter should fail"
    );
}
#[openmls_test]
fn app_data_update_increment_then_decrement() {
let alice_party = CorePartyState::<Provider>::new("alice");
let bob_party = CorePartyState::<Provider>::new("bob");
let mut group_state = setup_group_with_app_data_support(&alice_party, &bob_party, ciphersuite);
let [alice, bob] = group_state.members_mut(&["alice", "bob"]);
{
let mut commit_stage = alice
.group
.commit_builder()
.add_proposals(vec![Proposal::AppDataUpdate(Box::new(
AppDataUpdateProposal::update(
COUNTER_COMPONENT_ID,
CounterOperation::Increment.to_bytes(),
),
))])
.load_psks(alice_party.provider.storage())
.expect("failed to load PSKs");
let mut alice_updater = commit_stage.app_data_dictionary_updater();
let proposals: Vec<_> = commit_stage.app_data_update_proposals().collect();
process_app_data_proposals(&mut alice_updater, proposals.into_iter())
.expect("failed to process");
commit_stage.with_app_data_dictionary_updates(alice_updater.changes());
let commit_bundle = commit_stage
.build(
alice_party.provider.rand(),
alice_party.provider.crypto(),
&alice.party.signer,
|_| true,
)
.expect("failed to build")
.stage_commit(&alice_party.provider)
.expect("failed to stage");
let (commit_message, _, _) = commit_bundle.into_contents();
let commit_in: MlsMessageIn = commit_message.into();
let unverified = bob
.group
.unprotect_message(
&bob_party.provider,
commit_in.into_protocol_message().unwrap(),
)
.unwrap();
let mut bob_updater = bob.group.app_data_dictionary_updater();
let committed = unverified.committed_proposals().unwrap();
let mut updates: Vec<AppDataUpdateProposal> = Vec::new();
for por in committed.iter() {
let validated = por
.clone()
.validate(
bob_party.provider.crypto(),
ciphersuite,
ProtocolVersion::Mls10,
)
.unwrap();
if let ProposalOrRef::Proposal(p) = validated {
if let Proposal::AppDataUpdate(u) = *p {
updates.push(*u);
}
}
}
process_app_data_proposals(&mut bob_updater, updates.iter()).unwrap();
let processed = bob
.group
.process_unverified_message_with_app_data_updates(
&bob_party.provider,
unverified,
bob_updater.changes(),
)
.unwrap();
if let ProcessedMessageContent::StagedCommitMessage(sc) = processed.into_content() {
bob.group
.merge_staged_commit(&bob_party.provider, *sc)
.unwrap();
}
alice
.group
.merge_pending_commit(&alice_party.provider)
.unwrap();
}
let dict = alice.group.extensions().app_data_dictionary().unwrap();
let val = u32::from_be_bytes(
dict.dictionary()
.get(&COUNTER_COMPONENT_ID)
.unwrap()
.try_into()
.unwrap(),
);
assert_eq!(val, 1);
{
let mut commit_stage = alice
.group
.commit_builder()
.add_proposals(vec![Proposal::AppDataUpdate(Box::new(
AppDataUpdateProposal::update(
COUNTER_COMPONENT_ID,
CounterOperation::Decrement.to_bytes(),
),
))])
.load_psks(alice_party.provider.storage())
.expect("failed to load PSKs");
let mut alice_updater = commit_stage.app_data_dictionary_updater();
let proposals: Vec<_> = commit_stage.app_data_update_proposals().collect();
process_app_data_proposals(&mut alice_updater, proposals.into_iter())
.expect("decrement should succeed");
commit_stage.with_app_data_dictionary_updates(alice_updater.changes());
let commit_bundle = commit_stage
.build(
alice_party.provider.rand(),
alice_party.provider.crypto(),
&alice.party.signer,
|_| true,
)
.expect("failed to build")
.stage_commit(&alice_party.provider)
.expect("failed to stage");
let (commit_message, _, _) = commit_bundle.into_contents();
let commit_in: MlsMessageIn = commit_message.into();
let unverified = bob
.group
.unprotect_message(
&bob_party.provider,
commit_in.into_protocol_message().unwrap(),
)
.unwrap();
let mut bob_updater = bob.group.app_data_dictionary_updater();
let committed = unverified.committed_proposals().unwrap();
let mut updates: Vec<AppDataUpdateProposal> = Vec::new();
for por in committed.iter() {
let validated = por
.clone()
.validate(
bob_party.provider.crypto(),
ciphersuite,
ProtocolVersion::Mls10,
)
.unwrap();
if let ProposalOrRef::Proposal(p) = validated {
if let Proposal::AppDataUpdate(u) = *p {
updates.push(*u);
}
}
}
process_app_data_proposals(&mut bob_updater, updates.iter()).unwrap();
let processed = bob
.group
.process_unverified_message_with_app_data_updates(
&bob_party.provider,
unverified,
bob_updater.changes(),
)
.unwrap();
if let ProcessedMessageContent::StagedCommitMessage(sc) = processed.into_content() {
bob.group
.merge_staged_commit(&bob_party.provider, *sc)
.unwrap();
}
alice
.group
.merge_pending_commit(&alice_party.provider)
.unwrap();
}
let dict = alice.group.extensions().app_data_dictionary().unwrap();
let val = u32::from_be_bytes(
dict.dictionary()
.get(&COUNTER_COMPONENT_ID)
.unwrap()
.try_into()
.unwrap(),
);
assert_eq!(val, 0);
assert_eq!(
alice.group.extensions().app_data_dictionary(),
bob.group.extensions().app_data_dictionary()
);
}