use openmls_rust_crypto::MemoryStorage;
use prost::Message;
use sha2::{Digest, Sha256};
use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque};
use tracing::info;
use crate::core::{
ProposalId, error::CoreError, group_handle::GroupHandle, types::ProcessResult,
types::invitation_from_bytes,
};
use crate::ds::{APP_MSG_SUBTOPIC, OutboundPacket, WELCOME_SUBTOPIC};
use crate::mls_crypto::{
CommitResult, DeMlsStorage, DecryptResult, GroupUpdate, KeyPackageBytes, MlsService,
key_package_bytes_from_json,
};
use crate::protos::de_mls::messages::v1::{
AppMessage, BatchProposalsMessage, GroupUpdateRequest, InviteMember, UserKeyPackage,
WelcomeMessage, app_message, group_update_request, welcome_message,
};
/// Creates a new MLS group called `name` and returns a handle for it built with
/// `GroupHandle::new_as_creator`.
///
/// # Errors
/// Propagates any error from `MlsService::create_group`; in that case no handle is built.
pub fn create_group<S>(name: &str, mls: &MlsService<S>) -> Result<GroupHandle, CoreError>
where
    S: DeMlsStorage<MlsStorage = MemoryStorage>,
{
    // The MLS group must exist before a creator handle is handed out.
    mls.create_group(name)?;
    let creator_handle = GroupHandle::new_as_creator(name);
    Ok(creator_handle)
}
/// Builds a handle for a group this node intends to join, via `GroupHandle::new_for_join`.
///
/// No MLS state is created here. The MLS group is joined later from a welcome message
/// (see `join_group_from_invite` / the welcome subtopic handler).
pub fn prepare_to_join(name: &str) -> GroupHandle {
    let pending_handle = GroupHandle::new_for_join(name);
    pending_handle
}
/// Joins an MLS group from serialized welcome bytes and marks `handle` as MLS-initialized.
///
/// Returns the group name reported by `MlsService::join_group`.
///
/// # Errors
/// Propagates any error from `MlsService::join_group`; the handle is left untouched in
/// that case.
pub fn join_group_from_invite<S>(
    handle: &mut GroupHandle,
    welcome_bytes: &[u8],
    mls: &MlsService<S>,
) -> Result<String, CoreError>
where
    S: DeMlsStorage<MlsStorage = MemoryStorage>,
{
    let joined_name = mls.join_group(welcome_bytes)?;
    // Only flip the flag once the MLS join has actually succeeded.
    handle.set_mls_initialized();
    Ok(joined_name)
}
/// Marks this node as the steward of the group by calling `GroupHandle::become_steward`.
pub fn become_steward(handle: &mut GroupHandle) {
    handle.become_steward()
}
/// Gives up the steward role for this group by calling `GroupHandle::resign_steward`.
pub fn resign_steward(handle: &mut GroupHandle) {
    handle.resign_steward()
}
/// Encodes `app_msg`, passes it through `MlsService::encrypt` for this group, and wraps
/// the result in an [`OutboundPacket`] on the application subtopic.
///
/// # Errors
/// - `CoreError::MlsGroupNotInitialized` if the handle has not joined/created its MLS group.
/// - Any error from `MlsService::encrypt`.
pub fn build_message<S>(
    handle: &GroupHandle,
    mls: &MlsService<S>,
    app_msg: &AppMessage,
) -> Result<OutboundPacket, CoreError>
where
    S: DeMlsStorage<MlsStorage = MemoryStorage>,
{
    if !handle.is_mls_initialized() {
        return Err(CoreError::MlsGroupNotInitialized);
    }

    let group = handle.group_name();
    let encoded = app_msg.encode_to_vec();
    let sealed = mls.encrypt(group, &encoded)?;

    Ok(OutboundPacket::new(sealed, APP_MSG_SUBTOPIC, group, handle.app_id()))
}
/// Generates a fresh key package and publishes it on the welcome subtopic, wrapped as a
/// `UserKeyPackage` inside a `WelcomeMessage`.
///
/// The MLS group does not need to be initialized, so a would-be joiner can call this.
/// The steward side parses the bytes with `key_package_bytes_from_json`, which suggests
/// `KeyPackageBytes::as_bytes` yields JSON. NOTE(review): confirm that format.
///
/// # Errors
/// Propagates any error from `MlsService::generate_key_package`.
pub fn build_key_package_message<S>(
    handle: &GroupHandle,
    mls: &MlsService<S>,
) -> Result<OutboundPacket, CoreError>
where
    S: DeMlsStorage<MlsStorage = MemoryStorage>,
{
    let fresh_package = mls.generate_key_package()?;
    let kp_message = UserKeyPackage {
        key_package_bytes: fresh_package.as_bytes().to_vec(),
    };
    let envelope: WelcomeMessage = kp_message.into();
    let wire_bytes = envelope.encode_to_vec();

    Ok(OutboundPacket::new(
        wire_bytes,
        WELCOME_SUBTOPIC,
        handle.group_name(),
        handle.app_id(),
    ))
}
pub fn process_inbound<S>(
handle: &mut GroupHandle,
payload: &[u8],
subtopic: &str,
mls: &MlsService<S>,
) -> Result<ProcessResult, CoreError>
where
S: DeMlsStorage<MlsStorage = MemoryStorage>,
{
match subtopic {
WELCOME_SUBTOPIC => process_welcome_subtopic(handle, payload, mls),
APP_MSG_SUBTOPIC => process_app_subtopic(handle, payload, mls),
_ => Err(CoreError::InvalidSubtopic(subtopic.to_string())),
}
}
/// Handles a `WelcomeMessage` received on the welcome subtopic.
///
/// - `UserKeyPackage`: if this node is the steward, the key package is parsed and
///   returned as an `InviteMember` group-update request. Otherwise it is ignored.
/// - `InvitationToJoin`: ignored by the steward and by nodes already MLS-initialized.
///   Otherwise, if `MlsService::is_welcome_for_us` accepts it, the group is joined and
///   `handle` is marked initialized.
/// - No payload: ignored.
///
/// Note that the "Joined group" log line prints `handle.group_name()`, while the returned
/// `JoinedGroup` carries the name reported by `MlsService::join_group`.
///
/// # Errors
/// Propagates decode errors, `key_package_bytes_from_json` errors, and errors from
/// `is_welcome_for_us` / `join_group`.
fn process_welcome_subtopic<S>(
    handle: &mut GroupHandle,
    payload: &[u8],
    mls: &MlsService<S>,
) -> Result<ProcessResult, CoreError>
where
    S: DeMlsStorage<MlsStorage = MemoryStorage>,
{
    let Some(kind) = WelcomeMessage::decode(payload)?.payload else {
        return Ok(ProcessResult::Noop);
    };

    match kind {
        // Only the steward turns a joiner's key package into an invite request.
        welcome_message::Payload::UserKeyPackage(user_kp) if handle.is_steward() => {
            info!(
                "Steward received key package for group {}",
                handle.group_name()
            );
            let (key_package_bytes, identity) =
                key_package_bytes_from_json(user_kp.key_package_bytes)?;
            let invite = InviteMember {
                key_package_bytes,
                identity,
            };
            Ok(ProcessResult::GetUpdateRequest(GroupUpdateRequest {
                payload: Some(group_update_request::Payload::InviteMember(invite)),
            }))
        }
        welcome_message::Payload::UserKeyPackage(_) => Ok(ProcessResult::Noop),
        welcome_message::Payload::InvitationToJoin(invitation) => {
            // `||` short-circuits, so `is_welcome_for_us` is only consulted by nodes that
            // are neither the steward nor already in the group.
            let already_member = handle.is_steward() || handle.is_mls_initialized();
            if already_member || !mls.is_welcome_for_us(&invitation.mls_message_out_bytes)? {
                return Ok(ProcessResult::Noop);
            }

            let group_name = mls.join_group(&invitation.mls_message_out_bytes)?;
            handle.set_mls_initialized();
            info!(
                "[process_welcome_subtopic]: Joined group {}",
                handle.group_name()
            );
            Ok(ProcessResult::JoinedGroup(group_name))
        }
    }
}
/// Handles a payload received on the application subtopic.
///
/// Nothing is done until the handle is MLS-initialized. The payload is first tried as a
/// plaintext `AppMessage` carrying a `BatchProposalsMessage`, which is delegated to
/// `process_batch_proposals`. Any other payload is handed to `MlsService::decrypt` and its
/// result is mapped:
/// - `Application` is decoded as an `AppMessage` and converted to a `ProcessResult`.
/// - `Removed` becomes `LeaveGroup`.
/// - `ProposalStored` / `CommitProcessed` / `Ignored` become `Noop`.
///
/// # Errors
/// Propagates errors from `decrypt`, from decoding the decrypted `AppMessage`, from its
/// conversion, and from `process_batch_proposals`.
fn process_app_subtopic<S>(
    handle: &mut GroupHandle,
    payload: &[u8],
    mls: &MlsService<S>,
) -> Result<ProcessResult, CoreError>
where
    S: DeMlsStorage<MlsStorage = MemoryStorage>,
{
    if !handle.is_mls_initialized() {
        return Ok(ProcessResult::Noop);
    }

    // A failed plaintext decode is not an error: it just means the payload is an MLS
    // message and falls through to `decrypt`.
    let plaintext_batch = AppMessage::decode(payload)
        .ok()
        .and_then(|msg| match msg.payload {
            Some(app_message::Payload::BatchProposalsMessage(batch)) => Some(batch),
            _ => None,
        });
    if let Some(batch) = plaintext_batch {
        return process_batch_proposals(handle, batch, mls);
    }

    match mls.decrypt(handle.group_name(), payload)? {
        DecryptResult::Application(app_bytes) => {
            AppMessage::decode(app_bytes.as_ref())?.try_into()
        }
        DecryptResult::Removed => Ok(ProcessResult::LeaveGroup),
        DecryptResult::ProposalStored
        | DecryptResult::CommitProcessed
        | DecryptResult::Ignored => Ok(ProcessResult::Noop),
    }
}
/// Computes a SHA-256 digest over a set of proposals, independent of map iteration order.
///
/// Entries are visited in ascending `ProposalId` order. For each one, the id's
/// little-endian bytes are hashed, followed by the protobuf encoding of its request.
///
/// NOTE(review): the fields are concatenated without length prefixes, so the byte stream
/// is not an unambiguous framing of variable-length encodings. Changing the format would
/// break digest agreement with peers running the current code, so it is left as is.
fn compute_proposals_digest(proposals: &HashMap<ProposalId, GroupUpdateRequest>) -> Vec<u8> {
    let ordered: BTreeMap<ProposalId, &GroupUpdateRequest> = proposals
        .iter()
        .map(|(id, request)| (*id, request))
        .collect();

    ordered
        .iter()
        .fold(Sha256::new(), |mut hasher, (id, request)| {
            hasher.update(id.to_le_bytes());
            hasher.update(request.encode_to_vec());
            hasher
        })
        .finalize()
        .to_vec()
}
/// Validates a plaintext batch commit received on the app subtopic. If it matches this
/// node's locally approved proposals, its MLS proposals and commit are fed to the MLS group.
///
/// Validation runs in this order. Any failure logs a warning and returns
/// `ProcessResult::Noop` before any MLS state is touched:
/// 1. the set of proposal IDs in the batch equals the set of locally approved IDs;
/// 2. that set is non-empty;
/// 3. `proposals_digest` equals `compute_proposals_digest` over the local proposals;
/// 4. the number of MLS proposal payloads equals the number of approved IDs.
///
/// Next, each MLS proposal is passed to `mls.decrypt` and must yield
/// `DecryptResult::ProposalStored`. Finally the commit is passed to `mls.decrypt`.
/// Approved proposals are cleared when the commit yields `CommitProcessed`, `Removed`
/// or `Application`; for any other result they are kept.
///
/// # Returns
/// - `GroupUpdated` when the commit yields `CommitProcessed`,
/// - `LeaveGroup` when the commit yields `Removed`,
/// - the converted `AppMessage` when the commit yields `Application`,
/// - `Noop` on any rejection or other commit result.
///
/// # Errors
/// Propagates errors from `mls.decrypt`, `AppMessage::decode`, and the `try_into`
/// conversion.
fn process_batch_proposals<S>(
    handle: &mut GroupHandle,
    batch_msg: BatchProposalsMessage,
    mls: &MlsService<S>,
) -> Result<ProcessResult, CoreError>
where
    S: DeMlsStorage<MlsStorage = MemoryStorage>,
{
    let group_name = handle.group_name().to_owned();
    let local_proposals = handle.approved_proposals();
    // Compared as sets, so ordering (and duplicates) in `batch_msg.proposal_ids` is ignored.
    let local_ids: BTreeSet<ProposalId> = local_proposals.keys().copied().collect();
    let batch_ids: BTreeSet<ProposalId> = batch_msg.proposal_ids.iter().copied().collect();
    if local_ids != batch_ids {
        tracing::warn!(
            "Rejecting batch for group {}: proposal ID set mismatch \
            (local={:?}, batch={:?})",
            group_name,
            local_ids,
            batch_ids,
        );
        return Ok(ProcessResult::Noop);
    }
    // Sets are equal here, so this only triggers when both are empty.
    if local_ids.is_empty() {
        tracing::warn!(
            "Rejecting batch for group {}: no approved proposals",
            group_name,
        );
        return Ok(ProcessResult::Noop);
    }
    // The digest binds the batch to the exact contents of the locally approved requests,
    // not just their IDs.
    let local_digest = compute_proposals_digest(&local_proposals);
    if batch_msg.proposals_digest != local_digest {
        tracing::warn!(
            "Rejecting batch for group {}: proposals digest mismatch",
            group_name,
        );
        return Ok(ProcessResult::Noop);
    }
    // The digest covers the application-level requests only, not `mls_proposals`, so the
    // payload count is checked separately.
    // NOTE(review): the MLS payloads themselves are not compared with the approved
    // requests in this function.
    if batch_msg.mls_proposals.len() != local_ids.len() {
        tracing::warn!(
            "Rejecting batch for group {}: proposal count ({}) \
            does not match MLS payload count ({})",
            group_name,
            local_ids.len(),
            batch_msg.mls_proposals.len(),
        );
        return Ok(ProcessResult::Noop);
    }
    // NOTE(review): if a later proposal is rejected here, or `decrypt` errors via `?`,
    // proposals stored by earlier iterations are not rolled back by this function.
    // Confirm whether MlsService discards pending proposals on its own.
    for (i, proposal_bytes) in batch_msg.mls_proposals.iter().enumerate() {
        match mls.decrypt(&group_name, proposal_bytes)? {
            DecryptResult::ProposalStored => {}
            other => {
                tracing::warn!(
                    "Rejecting batch for group {}: proposal {} \
                    returned {:?}, expected ProposalStored",
                    group_name,
                    i,
                    other,
                );
                return Ok(ProcessResult::Noop);
            }
        }
    }
    let res = mls.decrypt(&group_name, &batch_msg.commit_message)?;
    match res {
        DecryptResult::CommitProcessed => {
            handle.clear_approved_proposals();
            Ok(ProcessResult::GroupUpdated)
        }
        DecryptResult::Removed => {
            handle.clear_approved_proposals();
            Ok(ProcessResult::LeaveGroup)
        }
        DecryptResult::Application(app_bytes) => {
            handle.clear_approved_proposals();
            AppMessage::decode(app_bytes.as_ref())?.try_into()
        }
        // `ProposalStored` / `Ignored` for the commit: keep approvals for a later batch.
        other => {
            tracing::warn!(
                "Unexpected commit result for group {}: {:?}, \
                keeping approved proposals",
                group_name,
                other,
            );
            Ok(ProcessResult::Noop)
        }
    }
}
/// Returns the number of approved proposals currently held by `handle`.
pub fn approved_proposals_count(handle: &GroupHandle) -> usize {
    let count = handle.approved_proposals_count();
    count
}
/// Returns an owned map of the approved proposals held by `handle`, keyed by proposal id.
pub fn approved_proposals(handle: &GroupHandle) -> HashMap<ProposalId, GroupUpdateRequest> {
    let snapshot = handle.approved_proposals();
    snapshot
}
/// Borrows the handle's epoch history: a queue of proposal maps keyed by proposal id.
///
/// NOTE(review): presumably each entry holds the proposals applied in one past epoch.
/// Confirm against `GroupHandle`.
pub fn epoch_history(handle: &GroupHandle) -> &VecDeque<HashMap<ProposalId, GroupUpdateRequest>> {
    let history = handle.epoch_history();
    history
}
pub fn create_batch_proposals<S>(
handle: &mut GroupHandle,
mls: &MlsService<S>,
) -> Result<Vec<OutboundPacket>, CoreError>
where
S: DeMlsStorage<MlsStorage = MemoryStorage>,
{
if !handle.is_steward() {
return Err(CoreError::StewardNotSet);
}
let proposals = handle.approved_proposals();
if proposals.is_empty() {
return Err(CoreError::NoProposals);
}
if !handle.is_mls_initialized() {
return Err(CoreError::MlsGroupNotInitialized);
}
let proposal_ids: Vec<u32> = proposals.keys().copied().collect();
let proposals_digest = compute_proposals_digest(&proposals);
let mut updates = Vec::with_capacity(proposals.len());
for (_, proposal) in proposals {
match proposal.payload {
Some(group_update_request::Payload::InviteMember(im)) => {
updates.push(GroupUpdate::Add(KeyPackageBytes::new(
im.key_package_bytes,
im.identity,
)));
}
Some(group_update_request::Payload::RemoveMember(identity)) => {
updates.push(GroupUpdate::Remove(identity.identity));
}
None => return Err(CoreError::InvalidGroupUpdateRequest),
}
}
let CommitResult {
proposals: mls_proposals,
commit,
welcome,
} = mls.commit(handle.group_name(), &updates)?;
let batch_msg: AppMessage = BatchProposalsMessage {
group_name: handle.group_name_bytes().to_vec(),
mls_proposals,
commit_message: commit,
proposal_ids,
proposals_digest,
}
.into();
let batch_packet = OutboundPacket::new(
batch_msg.encode_to_vec(),
APP_MSG_SUBTOPIC,
handle.group_name(),
handle.app_id(),
);
let mut messages = vec![batch_packet];
if let Some(welcome_bytes) = welcome {
let welcome_msg: WelcomeMessage = invitation_from_bytes(welcome_bytes);
let welcome_packet = OutboundPacket::new(
welcome_msg.encode_to_vec(),
WELCOME_SUBTOPIC,
handle.group_name(),
handle.app_id(),
);
messages.push(welcome_packet);
}
handle.clear_approved_proposals();
Ok(messages)
}
/// Returns the group's member list as reported by `MlsService::members`, as raw byte
/// vectors.
///
/// NOTE(review): presumably each entry is a member identity. Confirm against
/// `MlsService::members`.
///
/// # Errors
/// - `CoreError::MlsGroupNotInitialized` if the handle has not joined/created its MLS group.
/// - Any error from `MlsService::members`.
pub fn group_members<S>(
    handle: &GroupHandle,
    mls: &MlsService<S>,
) -> Result<Vec<Vec<u8>>, CoreError>
where
    S: DeMlsStorage<MlsStorage = MemoryStorage>,
{
    if !handle.is_mls_initialized() {
        return Err(CoreError::MlsGroupNotInitialized);
    }

    let roster = mls.members(handle.group_name())?;
    Ok(roster)
}