use ::serde::Deserialize;
use openmls_traits::OpenMlsProvider;
use tls_codec::{Deserialize as TlsDeserializeTrait, Serialize};
use crate::{
binary_tree::LeafNodeIndex,
ciphersuite::{Mac, Secret},
extensions::Extensions,
framing::Sender,
group::{
proposal_store::{ProposalQueue, QueuedProposal},
GroupContext, GroupEpoch, GroupId, PublicGroup,
},
messages::{proposals::Proposal, proposals_in::ProposalIn, ConfirmationTag},
test_utils::*,
treesync::{node::NodeIn, RatchetTree, RatchetTreeIn, TreeSync},
};
#[derive(Deserialize)]
struct TestElement {
cipher_suite: u16,
#[serde(with = "hex")]
tree_before: Vec<u8>,
#[serde(with = "hex")]
proposal: Vec<u8>,
proposal_sender: u32,
#[serde(with = "hex")]
tree_after: Vec<u8>,
#[serde(with = "hex")]
tree_hash_after: Vec<u8>,
#[serde(with = "hex")]
tree_hash_before: Vec<u8>,
}
fn run_test_vector(test: TestElement, provider: &impl OpenMlsProvider) -> Result<(), String> {
let ciphersuite = Ciphersuite::try_from(test.cipher_suite).unwrap();
let group_id = GroupId::random(provider.rand());
let nodes = Vec::<Option<NodeIn>>::tls_deserialize_exact(test.tree_before).unwrap();
let ratchet_tree = RatchetTree::from(RatchetTreeIn::from_nodes(nodes));
let tree_before = TreeSync::from_ratchet_tree(provider.crypto(), ciphersuite, ratchet_tree)
.map_err(|e| format!("Error while creating tree sync: {e:?}"))?;
let tree_hash_before = tree_before.tree_hash();
assert_eq!(test.tree_hash_before, tree_hash_before);
let group_context = GroupContext::new(
ciphersuite,
group_id,
GroupEpoch::from(0),
vec![],
vec![],
Extensions::empty(),
);
let initial_confirmation_tag = ConfirmationTag(
Mac::new(
provider.crypto(),
ciphersuite,
&Secret::random(ciphersuite, provider.rand()).unwrap(),
&[],
)
.unwrap(),
);
let mut group = PublicGroup::new(
provider.crypto(),
tree_before,
group_context,
initial_confirmation_tag,
)
.unwrap();
let mut proposal_queue = ProposalQueue::default();
let proposal_in = ProposalIn::tls_deserialize_exact(test.proposal).unwrap();
let proposal = Proposal::from(proposal_in);
let queued_proposal = QueuedProposal::from_proposal_and_sender(
ciphersuite,
provider.crypto(),
proposal,
&Sender::Member(LeafNodeIndex::new(test.proposal_sender)),
)
.unwrap();
proposal_queue.add(queued_proposal);
let mut diff = group.empty_diff();
let apply_proposal_values = diff.apply_proposals(&proposal_queue, None).unwrap();
diff.update_group_context(provider.crypto(), apply_proposal_values.extensions.clone())
.unwrap();
let staged_diff = diff
.into_staged_diff(provider.crypto(), ciphersuite)
.unwrap();
group.merge_diff(staged_diff);
let tree_after = group
.export_ratchet_tree()
.tls_serialize_detached()
.unwrap();
assert_eq!(test.tree_after, tree_after);
let tree_hash_after = group.group_context().tree_hash();
assert_eq!(test.tree_hash_after, tree_hash_after);
Ok(())
}
#[openmls_test::openmls_test]
fn read_test_vectors_tree_operations() {
let provider = &Provider::default();
let _ = pretty_env_logger::try_init();
log::debug!("Reading test vectors ...");
let tests: Vec<TestElement> = read_json!("../../../../test_vectors/tree-operations.json");
for test_vector in tests {
match run_test_vector(test_vector, provider) {
Ok(_) => {}
Err(e) => panic!("Error while checking tree operations test vector.\n{e:?}"),
}
}
log::trace!("Finished test vector verification");
}