use gbp_stack::{
ControlOpcode, GbpFlags, GroupNode, MlsContext, NodeState, ProcessedKind, StreamType,
label_for,
};
use openmls::prelude::DeserializeBytes as _;
use openmls::prelude::{KeyPackage, KeyPackageIn, ProtocolVersion};
use openmls_traits::OpenMlsProvider as _;
/// Decodes a TLS-serialized key package from `raw` and validates it for MLS 1.0
/// using the crypto backend of `ctx.provider`.
///
/// # Panics
/// Panics with `"kp parse"` if `raw` does not decode as exactly one `KeyPackageIn`,
/// and with `"kp validate"` if validation fails.
fn validated_kp(ctx: &MlsContext, raw: &[u8]) -> KeyPackage {
    // `tls_deserialize_exact_bytes` comes from the `DeserializeBytes` trait imported above.
    let unverified = KeyPackageIn::tls_deserialize_exact_bytes(raw).expect("kp parse");
    let crypto = ctx.provider.crypto();
    let outcome = unverified.validate(crypto, ProtocolVersion::Mls10);
    outcome.expect("kp validate")
}
/// Happy path for a two-member group:
/// 1. Alice invites Bob. The commit is only staged, so Alice stays at epoch 0.
/// 2. Bob joins from the welcome at epoch 1, and Alice finalizes to epoch 1.
/// 3. The group-node layer executes transition 1 on both sides without errors.
/// 4. A text payload from Alice is delivered and decrypted at Bob.
#[test]
fn two_party_add_completes_full_handshake() {
    // MLS setup: Alice needs Bob's key package in serialized, then validated, form.
    let (mut alice, _alice_kp) = MlsContext::new_member(b"alice").unwrap();
    let (mut bob, bob_bundle) = MlsContext::new_member(b"bob").unwrap();
    let bob_kp_wire =
        openmls::prelude::tls_codec::Serialize::tls_serialize_detached(bob_bundle.key_package())
            .unwrap();
    let bob_kp = validated_kp(&alice, &bob_kp_wire);

    // Inviting stages the commit but must not move Alice's epoch yet.
    let (commit, welcome) = alice.invite_full(&[bob_kp]).unwrap();
    assert_eq!(alice.epoch(), 0, "invite_full must NOT advance epoch");
    assert!(!commit.is_empty());
    assert!(!welcome.is_empty());

    // Bob joins from the welcome and lands on epoch 1 in the same group...
    bob.accept_welcome(&welcome).unwrap();
    assert_eq!(bob.epoch(), 1);
    assert_eq!(bob.group_id_16(), alice.group_id_16());

    // ...and Alice catches up by finalizing her staged commit.
    alice.finalize_pending_commit().unwrap();
    assert_eq!(alice.epoch(), 1);

    // Group-node layer: node 1 is the creator, node 2 joins expecting transition 1.
    let mut alice_node = GroupNode::new(1, alice.group_id_16());
    let mut bob_node = GroupNode::new(2, bob.group_id_16());
    alice_node.bootstrap_as_creator(0);
    bob_node.bootstrap_as_joiner(0, 1);
    assert_eq!(bob_node.pending_transition_id, 1);

    // Alice sends EXECUTE for transition 1, applies it locally, then Bob receives it.
    let execute = alice_node
        .send_control(&mut alice, 0, ControlOpcode::ExecuteTransition, 1, 7, vec![])
        .unwrap();
    alice_node.apply_transition(1);
    let delivered = bob_node.on_wire(&mut bob, &execute.wire).unwrap();

    // Collect any error codes surfaced while handling the EXECUTE message.
    let mut errs: Vec<u16> = Vec::new();
    for event in &delivered {
        if let gbp_stack::Event::Error { code, .. } = event {
            errs.push(*code);
        }
    }
    assert!(errs.is_empty(), "got errors during EXECUTE delivery: {errs:?}");

    // Both nodes now agree on the transition, the epoch and the state.
    assert_eq!(alice_node.last_transition_id, 1);
    assert_eq!(bob_node.last_transition_id, 1);
    assert_eq!(alice_node.current_epoch, 1);
    assert_eq!(bob_node.current_epoch, 1);
    assert_eq!(alice_node.state, NodeState::Active);
    assert_eq!(bob_node.state, NodeState::Active);

    // Application data: a text payload addressed to member 2 on its stream.
    let stream_id = alice_node.member_stream_id(2);
    let outbound = alice_node
        .send_payload(
            &mut alice,
            2,
            StreamType::Text,
            stream_id,
            GbpFlags::ordered_reliable_ack(),
            b"hi bob",
        )
        .unwrap();
    let inbound = bob_node.on_wire(&mut bob, &outbound.wire).unwrap();

    // Take the first PayloadReceived event and ignore everything else.
    let mut found = None;
    for event in inbound {
        if let gbp_stack::Event::PayloadReceived(p) = event {
            found = Some(p);
            break;
        }
    }
    let payload = found.expect("payload");
    assert_eq!(payload.plaintext, b"hi bob");
}
/// Aborting an in-flight add: after `invite_full` stages a commit,
/// `clear_pending_commit` must succeed and `epoch()` must still report 0.
#[test]
fn abort_rolls_back_pending_commit() {
    let (mut alice, _) = MlsContext::new_member(b"alice").unwrap();
    // Keep Bob's context alive (named binding) for the whole test, as before.
    let (_bob, bundle) = MlsContext::new_member(b"bob").unwrap();
    let encoded =
        openmls::prelude::tls_codec::Serialize::tls_serialize_detached(bundle.key_package())
            .unwrap();
    let invitee = validated_kp(&alice, &encoded);

    // Commit and welcome bytes are deliberately discarded; only the staged state matters.
    let _ = alice.invite_full(&[invitee]).unwrap();
    assert_eq!(alice.epoch(), 0);

    alice.clear_pending_commit().unwrap();
    assert_eq!(alice.epoch(), 0, "epoch must stay at 0 after abort");
}
/// An existing member receiving a later commit:
/// - `process_message` reports it as a commit but leaves Bob's epoch unchanged (staged).
/// - `finalize_pending_commit` then moves Bob from epoch 1 to 2.
/// - Alice reaches epoch 2 by finalizing her own staged commit.
#[test]
fn process_message_on_existing_member_advances_epoch() {
    // Epoch 0 -> 1: Alice adds Bob and both sides converge.
    let (mut alice, _) = MlsContext::new_member(b"alice").unwrap();
    let (mut bob, bob_bundle) = MlsContext::new_member(b"bob").unwrap();
    let bob_encoded =
        openmls::prelude::tls_codec::Serialize::tls_serialize_detached(bob_bundle.key_package())
            .unwrap();
    let bob_kp = validated_kp(&alice, &bob_encoded);
    let (_first_commit, bob_welcome) = alice.invite_full(&[bob_kp]).unwrap();
    bob.accept_welcome(&bob_welcome).unwrap();
    alice.finalize_pending_commit().unwrap();
    assert_eq!(alice.epoch(), 1);
    assert_eq!(bob.epoch(), 1);

    // Epoch 1 -> 2: Alice adds Carol. Bob only ever sees the commit, and Carol never joins.
    let (_carol, carol_bundle) = MlsContext::new_member(b"carol").unwrap();
    let carol_encoded =
        openmls::prelude::tls_codec::Serialize::tls_serialize_detached(carol_bundle.key_package())
            .unwrap();
    let carol_kp = validated_kp(&alice, &carol_encoded);
    let (carol_commit, _carol_welcome) = alice.invite_full(&[carol_kp]).unwrap();
    assert_eq!(bob.epoch(), 1);

    // Bob processes the commit (staged only), then finalizes it.
    let processed = bob.process_message(&carol_commit).unwrap();
    assert_eq!(processed, ProcessedKind::Commit);
    assert_eq!(bob.epoch(), 1, "staged but not merged");
    bob.finalize_pending_commit().unwrap();
    assert_eq!(bob.epoch(), 2, "finalize merges the staged commit");

    // Alice finalizes her side of the same commit.
    alice.finalize_pending_commit().unwrap();
    assert_eq!(alice.epoch(), 2);
}
/// `seal` followed by `open` with the same label and numeric argument returns the
/// original bytes. Opening with a different numeric argument (2 instead of 1) must fail.
/// NOTE(review): the numeric argument looks like a sequence/nonce input — confirm
/// against `MlsContext::seal`.
#[test]
fn aead_round_trips_under_label() {
    let (alice, _) = MlsContext::new_member(b"alice").unwrap();
    let text_label = label_for(StreamType::Text);
    let message = b"the quick brown fox";

    let sealed = alice.seal(text_label, 1, message).unwrap();
    let opened = alice.open(text_label, 1, &sealed).unwrap();
    assert_eq!(opened, message);

    let mismatched = alice.open(text_label, 2, &sealed);
    assert!(mismatched.is_err());
}