use super::*;
use crate::{
extensions::Extensions,
key_packages::Lifetime,
tree::sender_ratchet::SenderRatchetConfiguration,
treesync::{errors::LeafNodeValidationError, node::leaf_node::Capabilities},
};
use serde::{Deserialize, Serialize};
/// Configuration used when joining an MLS group.
///
/// The same type is embedded in [`MlsGroupCreateConfig`] (its `join_config`
/// field), whose join-related getters read from it.
///
/// Build one with [`MlsGroupJoinConfig::builder`] or use [`Default`]. The
/// derived default sets the numeric fields to `0`,
/// `use_ratchet_tree_extension` to `false`, and `wire_format_policy` to
/// `WireFormatPolicy::default()` (i.e. `PURE_CIPHERTEXT_WIRE_FORMAT_POLICY`).
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`. With
/// sequence-based serde formats the field order is part of the encoding,
/// so avoid reordering fields.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct MlsGroupJoinConfig {
    /// Policy for the wire format of outgoing messages and the wire formats
    /// accepted for incoming messages.
    pub(crate) wire_format_policy: WireFormatPolicy,
    /// Padding size setting.
    /// NOTE(review): presumably the padding applied to outgoing encrypted
    /// messages; confirm units and semantics at the use site.
    pub(crate) padding_size: usize,
    /// Maximum number of past epochs setting.
    /// NOTE(review): presumably how many past epochs' secrets are kept so late
    /// messages can still be decrypted; confirm.
    pub(crate) max_past_epochs: usize,
    /// Number of resumption PSKs setting.
    /// NOTE(review): presumably how many resumption PSKs are retained; confirm.
    pub(crate) number_of_resumption_psks: usize,
    /// Ratchet tree extension flag.
    /// NOTE(review): presumably whether the ratchet tree is sent as an
    /// extension (e.g. in Welcome messages); confirm.
    pub(crate) use_ratchet_tree_extension: bool,
    /// Configuration for the sender ratchet.
    pub(crate) sender_ratchet_configuration: SenderRatchetConfiguration,
}
impl MlsGroupJoinConfig {
    /// Returns a builder for [`MlsGroupJoinConfig`], initialised with
    /// `MlsGroupJoinConfig::default()`.
    pub fn builder() -> MlsGroupJoinConfigBuilder {
        MlsGroupJoinConfigBuilder::new()
    }

    /// Returns the [`WireFormatPolicy`] of this configuration.
    pub fn wire_format_policy(&self) -> WireFormatPolicy {
        self.wire_format_policy
    }

    /// Returns the padding size of this configuration.
    pub fn padding_size(&self) -> usize {
        self.padding_size
    }

    /// Returns the configured maximum number of past epochs.
    ///
    /// Mirrors [`MlsGroupCreateConfig::max_past_epochs`] so both configuration
    /// types expose the same join-related settings.
    pub fn max_past_epochs(&self) -> usize {
        self.max_past_epochs
    }

    /// Returns the configured number of resumption PSKs.
    ///
    /// Mirrors [`MlsGroupCreateConfig::number_of_resumption_psks`].
    pub fn number_of_resumption_psks(&self) -> usize {
        self.number_of_resumption_psks
    }

    /// Returns whether the ratchet tree extension flag is set.
    ///
    /// Mirrors [`MlsGroupCreateConfig::use_ratchet_tree_extension`].
    pub fn use_ratchet_tree_extension(&self) -> bool {
        self.use_ratchet_tree_extension
    }

    /// Returns the [`SenderRatchetConfiguration`] of this configuration.
    pub fn sender_ratchet_configuration(&self) -> &SenderRatchetConfiguration {
        &self.sender_ratchet_configuration
    }
}
/// Configuration used when creating a new MLS group.
///
/// Build one with [`MlsGroupCreateConfig::builder`] or use [`Default`]. The
/// default uses `MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519` and the
/// `Default` value of every other field.
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`. With
/// sequence-based serde formats the field order is part of the encoding,
/// so avoid reordering fields.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MlsGroupCreateConfig {
    /// Capabilities; the builder checks `leaf_node_extensions` against these.
    /// NOTE(review): presumably advertised in the creator's leaf node; confirm.
    pub(crate) capabilities: Capabilities,
    /// Lifetime setting.
    /// NOTE(review): presumably the lifetime of the creator's leaf node; confirm.
    pub(crate) lifetime: Lifetime,
    /// Ciphersuite the group is created with.
    pub(crate) ciphersuite: Ciphersuite,
    /// Join-related settings; the join getters on this type read from here.
    pub(crate) join_config: MlsGroupJoinConfig,
    /// Extensions for the group context of the new group.
    pub(crate) group_context_extensions: Extensions<GroupContext>,
    /// Leaf node extensions; `with_leaf_node_extensions` only accepts them if
    /// `capabilities` contains them at the time of the call.
    pub(crate) leaf_node_extensions: Extensions<LeafNode>,
}
impl Default for MlsGroupCreateConfig {
    /// Builds a configuration that uses the
    /// `MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519` ciphersuite.
    ///
    /// Every other field, including the embedded [`MlsGroupJoinConfig`],
    /// takes its own `Default` value.
    fn default() -> Self {
        let ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
        Self {
            capabilities: Default::default(),
            lifetime: Default::default(),
            ciphersuite,
            join_config: Default::default(),
            group_context_extensions: Default::default(),
            leaf_node_extensions: Default::default(),
        }
    }
}
/// Builder for [`MlsGroupJoinConfig`], obtained via [`MlsGroupJoinConfig::builder`].
///
/// Starts from `MlsGroupJoinConfig::default()`. Each setter consumes and
/// returns the builder, and [`MlsGroupJoinConfigBuilder::build`] yields the
/// finished configuration.
// `Debug` is derived for parity with `MlsGroupCreateConfigBuilder`; the inner
// `MlsGroupJoinConfig` already derives `Debug`.
#[derive(Default, Debug)]
pub struct MlsGroupJoinConfigBuilder {
    join_config: MlsGroupJoinConfig,
}
impl MlsGroupJoinConfigBuilder {
fn new() -> Self {
Self {
join_config: MlsGroupJoinConfig::default(),
}
}
pub fn wire_format_policy(mut self, wire_format_policy: WireFormatPolicy) -> Self {
self.join_config.wire_format_policy = wire_format_policy;
self
}
pub fn padding_size(mut self, padding_size: usize) -> Self {
self.join_config.padding_size = padding_size;
self
}
pub fn max_past_epochs(mut self, max_past_epochs: usize) -> Self {
self.join_config.max_past_epochs = max_past_epochs;
self
}
pub fn number_of_resumption_psks(mut self, number_of_resumption_psks: usize) -> Self {
self.join_config.number_of_resumption_psks = number_of_resumption_psks;
self
}
pub fn use_ratchet_tree_extension(mut self, use_ratchet_tree_extension: bool) -> Self {
self.join_config.use_ratchet_tree_extension = use_ratchet_tree_extension;
self
}
pub fn sender_ratchet_configuration(
mut self,
sender_ratchet_configuration: SenderRatchetConfiguration,
) -> Self {
self.join_config.sender_ratchet_configuration = sender_ratchet_configuration;
self
}
pub fn build(self) -> MlsGroupJoinConfig {
self.join_config
}
}
impl MlsGroupCreateConfig {
    /// Returns a builder seeded with `MlsGroupCreateConfig::default()`.
    pub fn builder() -> MlsGroupCreateConfigBuilder {
        MlsGroupCreateConfigBuilder::new()
    }

    /// Returns the embedded [`MlsGroupJoinConfig`].
    ///
    /// The join-related getters below all read through this value.
    pub fn join_config(&self) -> &MlsGroupJoinConfig {
        &self.join_config
    }

    /// Returns the [`WireFormatPolicy`] from the join configuration.
    pub fn wire_format_policy(&self) -> WireFormatPolicy {
        self.join_config().wire_format_policy
    }

    /// Returns the padding size from the join configuration.
    pub fn padding_size(&self) -> usize {
        self.join_config().padding_size
    }

    /// Returns the maximum number of past epochs from the join configuration.
    pub fn max_past_epochs(&self) -> usize {
        self.join_config().max_past_epochs
    }

    /// Returns the number of resumption PSKs from the join configuration.
    pub fn number_of_resumption_psks(&self) -> usize {
        self.join_config().number_of_resumption_psks
    }

    /// Returns the ratchet tree extension flag from the join configuration.
    pub fn use_ratchet_tree_extension(&self) -> bool {
        self.join_config().use_ratchet_tree_extension
    }

    /// Returns the [`SenderRatchetConfiguration`] from the join configuration.
    pub fn sender_ratchet_configuration(&self) -> &SenderRatchetConfiguration {
        &self.join_config().sender_ratchet_configuration
    }

    /// Returns the group context extensions for the new group.
    pub fn group_context_extensions(&self) -> &Extensions<GroupContext> {
        &self.group_context_extensions
    }

    /// Returns the configured [`Lifetime`].
    pub fn lifetime(&self) -> &Lifetime {
        &self.lifetime
    }

    /// Returns the configured [`Ciphersuite`].
    pub fn ciphersuite(&self) -> Ciphersuite {
        self.ciphersuite
    }

    /// Test helper: builds a configuration for `ciphersuite` that sends
    /// plaintext (`AlwaysPlaintext`) and accepts both wire formats (`Mixed`).
    #[cfg(any(feature = "test-utils", test))]
    pub fn test_default(ciphersuite: Ciphersuite) -> Self {
        let wire_format_policy = WireFormatPolicy::new(
            OutgoingWireFormatPolicy::AlwaysPlaintext,
            IncomingWireFormatPolicy::Mixed,
        );
        Self::builder()
            .ciphersuite(ciphersuite)
            .wire_format_policy(wire_format_policy)
            .build()
    }
}
/// Builder for [`MlsGroupCreateConfig`], obtained via [`MlsGroupCreateConfig::builder`].
///
/// Starts from `MlsGroupCreateConfig::default()`. Each setter consumes and
/// returns the builder, except `with_leaf_node_extensions`, which returns a
/// `Result`.
///
/// NOTE(review): `with_leaf_node_extensions` validates against the
/// capabilities set at the time it is called. A later `capabilities` call is
/// not re-validated, so set capabilities first.
#[derive(Default, Debug)]
pub struct MlsGroupCreateConfigBuilder {
    config: MlsGroupCreateConfig,
}
impl MlsGroupCreateConfigBuilder {
fn new() -> Self {
MlsGroupCreateConfigBuilder {
config: MlsGroupCreateConfig::default(),
}
}
pub fn wire_format_policy(mut self, wire_format_policy: WireFormatPolicy) -> Self {
self.config.join_config.wire_format_policy = wire_format_policy;
self
}
pub fn padding_size(mut self, padding_size: usize) -> Self {
self.config.join_config.padding_size = padding_size;
self
}
pub fn max_past_epochs(mut self, max_past_epochs: usize) -> Self {
self.config.join_config.max_past_epochs = max_past_epochs;
self
}
pub fn number_of_resumption_psks(mut self, number_of_resumption_psks: usize) -> Self {
self.config.join_config.number_of_resumption_psks = number_of_resumption_psks;
self
}
pub fn use_ratchet_tree_extension(mut self, use_ratchet_tree_extension: bool) -> Self {
self.config.join_config.use_ratchet_tree_extension = use_ratchet_tree_extension;
self
}
pub fn capabilities(mut self, capabilities: Capabilities) -> Self {
self.config.capabilities = capabilities;
self
}
pub fn sender_ratchet_configuration(
mut self,
sender_ratchet_configuration: SenderRatchetConfiguration,
) -> Self {
self.config.join_config.sender_ratchet_configuration = sender_ratchet_configuration;
self
}
pub fn lifetime(mut self, lifetime: Lifetime) -> Self {
self.config.lifetime = lifetime;
self
}
pub fn ciphersuite(mut self, ciphersuite: Ciphersuite) -> Self {
self.config.ciphersuite = ciphersuite;
self
}
pub fn with_group_context_extensions(mut self, extensions: Extensions<GroupContext>) -> Self {
self.config.group_context_extensions = extensions;
self
}
pub fn with_leaf_node_extensions(
mut self,
extensions: Extensions<LeafNode>,
) -> Result<Self, LeafNodeValidationError> {
if !self.config.capabilities.contains_extensions(&extensions) {
return Err(LeafNodeValidationError::ExtensionsNotInCapabilities);
}
self.config.leaf_node_extensions = extensions;
Ok(self)
}
pub fn build(self) -> MlsGroupCreateConfig {
self.config
}
}
/// Which wire formats are accepted for incoming messages.
///
/// NOTE(review): variant order may matter for serde encodings; avoid reordering.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum IncomingWireFormatPolicy {
    /// Accept only `WireFormat::PrivateMessage`.
    AlwaysCiphertext,
    /// Accept only `WireFormat::PublicMessage`.
    AlwaysPlaintext,
    /// Accept both `WireFormat::PrivateMessage` and `WireFormat::PublicMessage`.
    Mixed,
}
impl IncomingWireFormatPolicy {
    /// Reports whether a message with `wire_format` is acceptable under this policy.
    ///
    /// Only `PrivateMessage` and `PublicMessage` can ever be accepted. Any
    /// other wire format yields `false` for every policy.
    pub(crate) fn is_compatible_with(&self, wire_format: WireFormat) -> bool {
        let is_private = wire_format == WireFormat::PrivateMessage;
        let is_public = wire_format == WireFormat::PublicMessage;
        match self {
            Self::AlwaysCiphertext => is_private,
            Self::AlwaysPlaintext => is_public,
            Self::Mixed => is_private || is_public,
        }
    }
}
/// Which wire format is used for outgoing messages.
///
/// Converts into a `WireFormat` via `From`.
/// NOTE(review): variant order may matter for serde encodings; avoid reordering.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum OutgoingWireFormatPolicy {
    /// Send as `WireFormat::PrivateMessage`.
    AlwaysCiphertext,
    /// Send as `WireFormat::PublicMessage`.
    AlwaysPlaintext,
}
/// Pairing of an outgoing and an incoming wire format policy.
///
/// Outside tests or the `test-utils` feature, values come from the
/// predefined `*_WIRE_FORMAT_POLICY` constants or `Default`
/// (`PURE_CIPHERTEXT_WIRE_FORMAT_POLICY`), because `new` is gated.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct WireFormatPolicy {
    // Wire format used when sending.
    outgoing: OutgoingWireFormatPolicy,
    // Wire formats accepted when receiving.
    incoming: IncomingWireFormatPolicy,
}
impl WireFormatPolicy {
    /// Combines an outgoing and an incoming policy.
    ///
    /// Only compiled with the `test-utils` feature or in tests.
    #[cfg(any(feature = "test-utils", test))]
    pub(crate) fn new(
        outgoing: OutgoingWireFormatPolicy,
        incoming: IncomingWireFormatPolicy,
    ) -> Self {
        WireFormatPolicy { outgoing, incoming }
    }

    /// Returns the policy applied to outgoing messages.
    pub fn outgoing(&self) -> OutgoingWireFormatPolicy {
        let Self { outgoing, .. } = *self;
        outgoing
    }

    /// Returns the policy applied to incoming messages.
    pub fn incoming(&self) -> IncomingWireFormatPolicy {
        let Self { incoming, .. } = *self;
        incoming
    }
}
impl Default for WireFormatPolicy {
    /// Defaults to encrypted-only traffic in both directions
    /// (`PURE_CIPHERTEXT_WIRE_FORMAT_POLICY`).
    fn default() -> WireFormatPolicy {
        PURE_CIPHERTEXT_WIRE_FORMAT_POLICY
    }
}
impl From<OutgoingWireFormatPolicy> for WireFormat {
    /// Maps an outgoing policy to the wire format it produces.
    ///
    /// Plaintext maps to `PublicMessage` and ciphertext maps to `PrivateMessage`.
    fn from(policy: OutgoingWireFormatPolicy) -> Self {
        match policy {
            OutgoingWireFormatPolicy::AlwaysPlaintext => Self::PublicMessage,
            OutgoingWireFormatPolicy::AlwaysCiphertext => Self::PrivateMessage,
        }
    }
}
/// All four predefined wire format policies, in the order pure-plaintext,
/// pure-ciphertext, mixed-plaintext, mixed-ciphertext.
///
/// NOTE(review): presumably intended for iterating over policies (e.g. in
/// tests); confirm before relying on the order.
pub const WIRE_FORMAT_POLICIES: [WireFormatPolicy; 4] = [
    PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
    PURE_CIPHERTEXT_WIRE_FORMAT_POLICY,
    MIXED_PLAINTEXT_WIRE_FORMAT_POLICY,
    MIXED_CIPHERTEXT_WIRE_FORMAT_POLICY,
];
/// Sends `PublicMessage` and accepts only `PublicMessage`.
pub const PURE_PLAINTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
    outgoing: OutgoingWireFormatPolicy::AlwaysPlaintext,
    incoming: IncomingWireFormatPolicy::AlwaysPlaintext,
};
/// Sends `PrivateMessage` and accepts only `PrivateMessage`.
///
/// This is the value returned by `WireFormatPolicy::default()`.
pub const PURE_CIPHERTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
    outgoing: OutgoingWireFormatPolicy::AlwaysCiphertext,
    incoming: IncomingWireFormatPolicy::AlwaysCiphertext,
};
/// Sends `PublicMessage` and accepts both `PublicMessage` and `PrivateMessage`.
pub const MIXED_PLAINTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
    outgoing: OutgoingWireFormatPolicy::AlwaysPlaintext,
    incoming: IncomingWireFormatPolicy::Mixed,
};
/// Sends `PrivateMessage` and accepts both `PrivateMessage` and `PublicMessage`.
pub const MIXED_CIPHERTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
    outgoing: OutgoingWireFormatPolicy::AlwaysCiphertext,
    incoming: IncomingWireFormatPolicy::Mixed,
};