use super::{
proposals::{ProposalStore, QueuedProposal},
staged_commit::StagedCommit,
};
use crate::{
ciphersuite::{hash_ref::KeyPackageRef, signable::Signable},
credentials::{Credential, CredentialBundle},
error::LibraryError,
framing::*,
group::*,
key_packages::{KeyPackage, KeyPackageBundle, KeyPackageBundlePayload},
messages::{proposals::*, Welcome},
schedule::ResumptionSecret,
treesync::Node,
};
use openmls_traits::{key_store::OpenMlsKeyStore, types::Ciphersuite, OpenMlsCryptoProvider};
use std::io::{Error, Read, Write};
mod application;
mod creation;
mod exporting;
mod resumption;
mod ser;
mod updates;
use config::*;
use errors::*;
use resumption::*;
use ser::*;
pub(crate) mod config;
pub(crate) mod errors;
pub(crate) mod membership;
pub(crate) mod processing;
#[cfg(test)]
mod test_mls_group;
/// A [`StagedCommit`] that is pending in an [`MlsGroup`], tagged with the kind
/// of commit it is.
///
/// Only the [`PendingCommitState::Member`] kind can be discarded through
/// [`MlsGroup::clear_pending_commit`]; an `External` pending commit is kept.
#[derive(Debug, Serialize, Deserialize)]
pub enum PendingCommitState {
/// A pending commit of the member kind (presumably created by this client as
/// an existing group member — TODO confirm).
Member(StagedCommit),
/// A pending external commit (presumably created when joining the group via
/// an external commit — TODO confirm).
External(StagedCommit),
}
impl PendingCommitState {
    /// Borrows the wrapped [`StagedCommit`], whichever kind of pending commit
    /// this is.
    pub(crate) fn staged_commit(&self) -> &StagedCommit {
        match self {
            PendingCommitState::Member(staged) | PendingCommitState::External(staged) => staged,
        }
    }
}
impl From<PendingCommitState> for StagedCommit {
    /// Unwraps a pending commit state into its owned [`StagedCommit`],
    /// discarding the member/external tag.
    fn from(pending_state: PendingCommitState) -> Self {
        match pending_state {
            PendingCommitState::Member(staged_commit)
            | PendingCommitState::External(staged_commit) => staged_commit,
        }
    }
}
/// Lifecycle state of an [`MlsGroup`].
///
/// `MlsGroup::is_operational` accepts only `Operational`, while
/// `MlsGroup::is_active` is `false` only for `Inactive`.
#[derive(Debug, Serialize, Deserialize)]
pub enum MlsGroupState {
/// A commit is pending; it is exposed through `MlsGroup::pending_commit`.
/// Boxed, presumably to keep the enum small — TODO confirm.
PendingCommit(Box<PendingCommitState>),
/// No commit is pending; `is_operational` succeeds in this state.
Operational,
/// The group is inactive; guarded operations report
/// `MlsGroupStateError::UseAfterEviction`.
Inactive,
}
/// High-level handle to an MLS group: wraps the core group together with its
/// configuration, queued proposals, outgoing AAD, lifecycle state and a flag
/// tracking whether it still needs to be persisted.
#[derive(Debug)]
pub struct MlsGroup {
/// Group configuration; its outgoing wire format policy and padding size
/// control how outgoing messages are framed and encrypted.
mls_group_config: MlsGroupConfig,
/// The underlying core group (ciphersuite, context, ratchet tree, encryption).
group: CoreGroup,
/// Queued proposals, exposed through `pending_proposals`.
proposal_store: ProposalStore,
/// Key package bundles owned by this member (not used in this part of the
/// file; presumably kept for pending updates — TODO confirm).
own_kpbs: Vec<KeyPackageBundle>,
/// Additional authenticated data used when building framing parameters.
aad: Vec<u8>,
/// Store of resumption secrets (not used in this part of the file).
resumption_secret_store: ResumptionSecretStore,
/// Lifecycle state: operational, pending commit, or inactive.
group_state: MlsGroupState,
/// `Changed` after a state-mutating setter, `Persisted` after `save`.
state_changed: InnerState,
}
impl MlsGroup {
    /// Returns the configuration of this group.
    pub fn configuration(&self) -> &MlsGroupConfig {
        &self.mls_group_config
    }

    /// Replaces the group configuration with a clone of `mls_group_config`
    /// and flags the group state as changed.
    pub fn set_configuration(&mut self, mls_group_config: &MlsGroupConfig) {
        self.mls_group_config = mls_group_config.clone();
        self.flag_state_change();
    }

    /// Returns the additional authenticated data (AAD) used when framing
    /// outgoing messages.
    pub fn aad(&self) -> &[u8] {
        &self.aad
    }

    /// Sets the AAD to a copy of `aad` and flags the group state as changed.
    pub fn set_aad(&mut self, aad: &[u8]) {
        self.aad = aad.to_vec();
        self.flag_state_change();
    }

    /// Returns the ciphersuite of the underlying group.
    pub fn ciphersuite(&self) -> Ciphersuite {
        self.group.ciphersuite()
    }

    /// Returns `true` unless the group is in the [`MlsGroupState::Inactive`]
    /// state.
    pub fn is_active(&self) -> bool {
        !matches!(self.group_state, MlsGroupState::Inactive)
    }

    /// Returns the credential stored in this member's own leaf.
    ///
    /// # Errors
    /// - [`MlsGroupStateError::UseAfterEviction`] if the group is inactive.
    /// - A library error (converted into [`MlsGroupStateError`]) if the own
    ///   leaf node is missing from the tree.
    pub fn credential(&self) -> Result<&Credential, MlsGroupStateError> {
        if !self.is_active() {
            return Err(MlsGroupStateError::UseAfterEviction);
        }
        let tree = self.group.treesync();
        Ok(tree
            .own_leaf_node()
            .map_err(|_| LibraryError::custom("Own leaf node missing"))?
            .key_package()
            .credential())
    }

    /// Returns this member's own key package reference, as reported by the
    /// underlying group.
    pub fn key_package_ref(&self) -> Option<&KeyPackageRef> {
        self.group.key_package_ref()
    }

    /// Returns the group ID.
    pub fn group_id(&self) -> &GroupId {
        self.group.group_id()
    }

    /// Returns the current epoch from the group context.
    pub fn epoch(&self) -> GroupEpoch {
        self.group.context().epoch()
    }

    /// Iterates over the proposals currently held in the proposal store.
    pub fn pending_proposals(&self) -> impl Iterator<Item = &QueuedProposal> {
        self.proposal_store.proposals()
    }

    /// Returns the pending [`StagedCommit`], if the group is in the
    /// [`MlsGroupState::PendingCommit`] state.
    pub fn pending_commit(&self) -> Option<&StagedCommit> {
        match self.group_state {
            MlsGroupState::PendingCommit(ref pending_commit_state) => {
                Some(pending_commit_state.staged_commit())
            }
            MlsGroupState::Operational => None,
            MlsGroupState::Inactive => None,
        }
    }

    /// Discards a pending commit of the [`PendingCommitState::Member`] kind
    /// and returns the group to [`MlsGroupState::Operational`].
    ///
    /// A pending [`PendingCommitState::External`] commit is left in place, and
    /// the call is a no-op in every other state. When a commit is actually
    /// discarded the group state is flagged as changed, like every other
    /// state-mutating setter. Callers polling [`MlsGroup::state_changed`] then
    /// know the group must be persisted again.
    pub fn clear_pending_commit(&mut self) {
        let clearable = match &self.group_state {
            MlsGroupState::PendingCommit(pending_commit_state) => {
                matches!(**pending_commit_state, PendingCommitState::Member(_))
            }
            MlsGroupState::Operational | MlsGroupState::Inactive => false,
        };
        if clearable {
            self.group_state = MlsGroupState::Operational;
            // Previously missing: without this, `state_changed()` kept
            // reporting `Persisted` even though `group_state` was modified.
            self.flag_state_change();
        }
    }

    /// Loads a group from JSON, such as that written by [`MlsGroup::save`].
    ///
    /// `serde_json` does not buffer its input, so `reader` is wrapped in a
    /// [`std::io::BufReader`]. Otherwise sources such as a `File` would be
    /// read in many tiny reads.
    ///
    /// # Errors
    /// Returns an I/O error if reading fails or the input cannot be
    /// deserialized.
    pub fn load<R: Read>(reader: R) -> Result<MlsGroup, Error> {
        let serialized_mls_group: SerializedMlsGroup =
            serde_json::from_reader(std::io::BufReader::new(reader))?;
        Ok(serialized_mls_group.into_mls_group())
    }

    /// Serializes the group as pretty-printed JSON into `writer` and marks the
    /// state as [`InnerState::Persisted`].
    ///
    /// The JSON is fully rendered before anything is written. A serialization
    /// failure therefore leaves `writer` untouched.
    ///
    /// # Errors
    /// Returns an I/O error if serialization or writing fails; the state is
    /// not marked as persisted in that case.
    pub fn save<W: Write>(&mut self, writer: &mut W) -> Result<(), Error> {
        let serialized_mls_group = serde_json::to_string_pretty(self)?;
        writer.write_all(serialized_mls_group.as_bytes())?;
        self.state_changed = InnerState::Persisted;
        Ok(())
    }

    /// Returns whether the group changed since it was last saved.
    pub fn state_changed(&self) -> InnerState {
        self.state_changed
    }

    /// Exports the nodes of the ratchet tree.
    pub fn export_ratchet_tree(&self) -> Vec<Option<Node>> {
        self.group.treesync().export_nodes()
    }
}
impl MlsGroup {
    /// Converts `plaintext` into an outgoing [`MlsMessageOut`] according to
    /// the group's outgoing wire format policy.
    ///
    /// With `AlwaysPlaintext` the plaintext is wrapped unchanged. With
    /// `AlwaysCiphertext` the core group encrypts it using the configured
    /// padding size. An encryption failure is reported as a [`LibraryError`].
    fn plaintext_to_mls_message(
        &mut self,
        plaintext: MlsPlaintext,
        backend: &impl OpenMlsCryptoProvider,
    ) -> Result<MlsMessageOut, LibraryError> {
        let outgoing_policy = self.configuration().wire_format_policy().outgoing();
        match outgoing_policy {
            OutgoingWireFormatPolicy::AlwaysPlaintext => Ok(MlsMessageOut::from(plaintext)),
            OutgoingWireFormatPolicy::AlwaysCiphertext => {
                let padding_size = self.configuration().padding_size();
                self.group
                    .encrypt(plaintext, padding_size, backend)
                    .map_err(|_| LibraryError::custom("Malformed plaintext"))
                    .map(MlsMessageOut::from)
            }
        }
    }

    /// Marks the group as modified since it was last persisted.
    fn flag_state_change(&mut self) {
        self.state_changed = InnerState::Changed;
    }

    /// Builds [`FramingParameters`] from the group's AAD and its outgoing wire
    /// format policy.
    fn framing_parameters(&self) -> FramingParameters {
        let wire_format = self.configuration().wire_format_policy().outgoing();
        FramingParameters::new(&self.aad, wire_format)
    }

    /// Checks that the group is in the operational state.
    ///
    /// # Errors
    /// - [`MlsGroupStateError::PendingCommit`] while a commit is pending.
    /// - [`MlsGroupStateError::UseAfterEviction`] if the group is inactive.
    fn is_operational(&self) -> Result<(), MlsGroupStateError> {
        match &self.group_state {
            MlsGroupState::Operational => Ok(()),
            MlsGroupState::Inactive => Err(MlsGroupStateError::UseAfterEviction),
            MlsGroupState::PendingCommit(_) => Err(MlsGroupStateError::PendingCommit),
        }
    }
}
impl MlsGroup {
    /// Test helper: read-only access to the underlying [`CoreGroup`].
    #[cfg(test)]
    pub(crate) fn group(&self) -> &CoreGroup {
        &self.group
    }

    /// Test helper: empties the proposal store via `ProposalStore::empty`.
    #[cfg(test)]
    pub(crate) fn clear_pending_proposals(&mut self) {
        self.proposal_store.empty()
    }

    /// Returns the [`GroupContext`] of the underlying group.
    #[cfg(any(feature = "test-utils", test))]
    pub fn export_group_context(&self) -> &GroupContext {
        self.group.context()
    }

    /// Returns the tree hash reported by the group's ratchet tree.
    #[cfg(any(feature = "test-utils", test))]
    pub fn tree_hash(&self) -> &[u8] {
        let tree = self.group.treesync();
        tree.tree_hash()
    }

    /// Forwards `message` to the core group's `print_tree`.
    #[cfg(any(feature = "test-utils", test))]
    pub fn print_tree(&self, message: &str) {
        self.group.print_tree(message)
    }
}
/// Persistence status of an `MlsGroup`.
///
/// `Eq` is derived alongside `PartialEq` because equality on this fieldless
/// enum is total (clippy `derive_partial_eq_without_eq`). This lets callers
/// use it where `Eq` is required.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum InnerState {
    /// The group was modified since it was last saved.
    Changed,
    /// The group has been written out by `MlsGroup::save`.
    Persisted,
}