1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use crate::group::{proposal_filter::ProposalBundle, Roster};

#[cfg(feature = "private_message")]
use crate::{
    group::{padding::PaddingMode, Sender},
    WireFormat,
};

use alloc::boxed::Box;
use core::convert::Infallible;
use mls_rs_core::{
    error::IntoAnyError, extension::ExtensionList, group::Member, identity::SigningIdentity,
};

/// Whether a commit is being prepared locally or processed after being received.
///
/// Passed to `MlsRules::filter_proposals` so that rules can distinguish the two cases.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CommitDirection {
    /// The commit is being prepared by this client.
    Send,
    /// The commit was received and is being processed.
    Receive,
}

/// Identifies who authored a commit.
///
/// A commit comes either from someone who is already in the group, or from a
/// new member joining the group through an external commit.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CommitSource {
    /// The commit was authored by a current group member.
    ExistingMember(Member),
    /// The commit was authored by a newcomer joining via external commit,
    /// identified by its signing identity.
    NewMember(SigningIdentity),
}

/// Options that control how a commit is generated.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CommitOptions {
    /// Include an update path even when MLS would not mandate one.
    pub path_required: bool,
    /// Include the ratchet tree in the welcome message when the commit adds members.
    pub ratchet_tree_extension: bool,
    /// Produce one welcome message shared by all added members instead of one per member.
    pub single_welcome_message: bool,
}

impl Default for CommitOptions {
    fn default() -> Self {
        CommitOptions {
            path_required: false,
            ratchet_tree_extension: true,
            single_welcome_message: true,
        }
    }
}

impl CommitOptions {
    /// Creates options with the default settings; equivalent to `CommitOptions::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of these options with `path_required` replaced.
    ///
    /// When `true`, the commit includes an update path even where MLS does not
    /// mandate one. The receiver is taken by value, so the result must be used.
    #[must_use]
    pub fn with_path_required(self, path_required: bool) -> Self {
        Self {
            path_required,
            ..self
        }
    }

    /// Returns a copy of these options with `ratchet_tree_extension` replaced.
    ///
    /// When `true`, the ratchet tree is included in the welcome message if the
    /// commit adds members.
    #[must_use]
    pub fn with_ratchet_tree_extension(self, ratchet_tree_extension: bool) -> Self {
        Self {
            ratchet_tree_extension,
            ..self
        }
    }

    /// Returns a copy of these options with `single_welcome_message` replaced.
    ///
    /// When `true`, one welcome message is generated for all added members;
    /// otherwise one welcome message is generated per added member.
    #[must_use]
    pub fn with_single_welcome_message(self, single_welcome_message: bool) -> Self {
        Self {
            single_welcome_message,
            ..self
        }
    }
}

/// Options that control encryption of control and application messages.
///
/// Without the `private_message` feature this struct has no fields.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct EncryptionOptions {
    /// Whether control messages (proposals and commits) are sent encrypted.
    #[cfg(feature = "private_message")]
    pub encrypt_control_messages: bool,
    /// Padding mode applied to encrypted messages.
    #[cfg(feature = "private_message")]
    pub padding_mode: PaddingMode,
}

#[cfg(feature = "private_message")]
impl EncryptionOptions {
    /// Builds encryption options from explicit settings.
    pub fn new(encrypt_control_messages: bool, padding_mode: PaddingMode) -> Self {
        Self {
            padding_mode,
            encrypt_control_messages,
        }
    }

    /// Chooses the wire format for a control message sent by `sender`.
    ///
    /// A private message is used only when the sender is a group member and
    /// `encrypt_control_messages` is set; every other case uses a public message.
    pub(crate) fn control_wire_format(&self, sender: Sender) -> WireFormat {
        let sent_by_member = matches!(sender, Sender::Member(_));

        if sent_by_member && self.encrypt_control_messages {
            WireFormat::PrivateMessage
        } else {
            WireFormat::PublicMessage
        }
    }
}

/// A set of user-controlled rules that customize the behavior of MLS.
///
/// Implementations must be `Send + Sync`. Through `maybe_async`, `filter_proposals` is
/// compiled as an `async` method under the `mls_build_async` cfg and as a plain synchronous
/// method otherwise; the other methods are always synchronous.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
pub trait MlsRules: Send + Sync {
    /// Error type returned by every rule method; it must implement `IntoAnyError`.
    type Error: IntoAnyError;

    /// Pre-processes the set of committed proposals. This is called both when preparing
    /// a commit and when receiving one.
    ///
    /// Both the proposals received during the current epoch and those supplied at the time
    /// of the commit are presented here, as a raw list, for validation and filtering. The
    /// standard MLS rules are then applied internally to the bundle this method returns.
    ///
    /// Each member of a group MUST apply the same proposal rules in order to
    /// maintain a working group.
    ///
    /// Typically, any invalid proposal should result in an error. The exceptions are invalid
    /// by-reference proposals processed when _preparing_ a commit, which should be filtered
    /// out instead. This avoids a deadlock in which no commit can be generated after an
    /// invalid set of proposal messages has been received.
    ///
    /// The `ProposalBundle` may be modified arbitrarily. For example, a Remove proposal that
    /// removes a moderator can result in adding a GroupContextExtensions proposal that updates
    /// the moderator list in the group context. The resulting `ProposalBundle` is validated
    /// by the library.
    ///
    /// `direction` tells whether the commit is being sent or received, and `source`
    /// identifies who authored it. `current_roster` and `extension_list` give the group's
    /// current roster and extension list.
    async fn filter_proposals(
        &self,
        direction: CommitDirection,
        source: CommitSource,
        current_roster: &Roster,
        extension_list: &ExtensionList,
        proposals: ProposalBundle,
    ) -> Result<ProposalBundle, Self::Error>;

    /// Determines the options for a commit. It is called when preparing the commit and decides:
    /// - whether to enforce an update path where MLS does not mandate one;
    /// - whether to include the ratchet tree in the welcome message (if the commit adds members);
    /// - whether to generate a single welcome message, or one welcome message per added member.
    ///
    /// The `new_roster` and `new_extension_list` describe the group state after the commit.
    fn commit_options(
        &self,
        new_roster: &Roster,
        new_extension_list: &ExtensionList,
        proposals: &ProposalBundle,
    ) -> Result<CommitOptions, Self::Error>;

    /// Determines encryption settings. This is called when sending any packet. For proposals
    /// and commits, it decides whether to encrypt them. For any encrypted packet, it decides
    /// the padding mode used.
    ///
    /// Note that for commits, the `current_roster` and `current_extension_list` describe the group state
    /// before the commit, unlike in [commit_options](MlsRules::commit_options).
    fn encryption_options(
        &self,
        current_roster: &Roster,
        current_extension_list: &ExtensionList,
    ) -> Result<EncryptionOptions, Self::Error>;
}

// Implements `MlsRules` for a pointer-like wrapper `$implementer` around some `T: MlsRules`
// (it is invoked below for `Box<T>` and `&T`). Every method forwards to the pointee via
// `**self`, and the wrapper reuses `T::Error` as its error type. The `?Sized` bound allows
// unsized pointees (e.g. trait objects).
macro_rules! delegate_mls_rules {
    ($implementer:ty) => {
        // Same sync/async selection as on the `MlsRules` trait itself.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
        impl<T: MlsRules + ?Sized> MlsRules for $implementer {
            type Error = T::Error;

            // NOTE(review): this method-level attribute repeats the sync attribute already
            // applied to the impl; it looks redundant, so confirm before removing it.
            #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
            async fn filter_proposals(
                &self,
                direction: CommitDirection,
                source: CommitSource,
                current_roster: &Roster,
                extension_list: &ExtensionList,
                proposals: ProposalBundle,
            ) -> Result<ProposalBundle, Self::Error> {
                // `.await` is kept here; in sync builds `maybe_async` is expected to strip it.
                (**self)
                    .filter_proposals(direction, source, current_roster, extension_list, proposals)
                    .await
            }

            fn commit_options(
                &self,
                roster: &Roster,
                extension_list: &ExtensionList,
                proposals: &ProposalBundle,
            ) -> Result<CommitOptions, Self::Error> {
                (**self).commit_options(roster, extension_list, proposals)
            }

            fn encryption_options(
                &self,
                roster: &Roster,
                extension_list: &ExtensionList,
            ) -> Result<EncryptionOptions, Self::Error> {
                (**self).encryption_options(roster, extension_list)
            }
        }
    };
}

// Allow boxed rules and borrowed rules to be used wherever `MlsRules` is expected.
delegate_mls_rules!(Box<T>);
delegate_mls_rules!(&T);

/// Default MLS rules: proposals pass through unfiltered, and the commit and
/// encryption options are user-configurable.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct DefaultMlsRules {
    /// Returned unchanged from `MlsRules::commit_options`.
    pub commit_options: CommitOptions,
    /// Returned unchanged from `MlsRules::encryption_options`.
    pub encryption_options: EncryptionOptions,
}

impl DefaultMlsRules {
    /// Creates new MLS rules with default settings.
    ///
    /// The default commit options do not enforce an update path, put the ratchet
    /// tree in the welcome message, and produce a single welcome message. The
    /// encryption options are `EncryptionOptions::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns these rules with the commit options replaced; all other settings are kept.
    #[must_use]
    pub fn with_commit_options(self, commit_options: CommitOptions) -> Self {
        // Struct-update syntax carries over every other field, so this stays
        // correct if fields are added to this `#[non_exhaustive]` struct.
        Self {
            commit_options,
            ..self
        }
    }

    /// Returns these rules with the encryption options replaced; all other settings are kept.
    #[must_use]
    pub fn with_encryption_options(self, encryption_options: EncryptionOptions) -> Self {
        Self {
            encryption_options,
            ..self
        }
    }
}

/// Pass-through rules: proposals are returned as given, and the configured
/// commit and encryption options are returned regardless of group state.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
impl MlsRules for DefaultMlsRules {
    // These rules can never fail.
    type Error = Infallible;

    async fn filter_proposals(
        &self,
        _commit_direction: CommitDirection,
        _commit_source: CommitSource,
        _roster: &Roster,
        _extensions: &ExtensionList,
        proposals: ProposalBundle,
    ) -> Result<ProposalBundle, Self::Error> {
        // No custom filtering; the library still applies the standard MLS rules.
        Ok(proposals)
    }

    fn commit_options(
        &self,
        _new_roster: &Roster,
        _new_extensions: &ExtensionList,
        _proposals: &ProposalBundle,
    ) -> Result<CommitOptions, Self::Error> {
        let configured = self.commit_options;
        Ok(configured)
    }

    fn encryption_options(
        &self,
        _roster: &Roster,
        _extensions: &ExtensionList,
    ) -> Result<EncryptionOptions, Self::Error> {
        let configured = self.encryption_options;
        Ok(configured)
    }
}