// mls_spec/key_schedule.rs

1use crate::{
2    SensitiveBytes,
3    defs::{CiphersuiteId, Epoch, ProtocolVersion, WireFormat, labels::KdfLabelKind},
4    group::{ExternalSender, GroupId, RequiredCapabilities, extensions::Extension},
5    messages::FramedContent,
6    tree::TreeHash,
7};
8
9#[derive(
10    Debug,
11    Clone,
12    PartialEq,
13    Eq,
14    Default,
15    tls_codec::TlsSerialize,
16    tls_codec::TlsDeserialize,
17    tls_codec::TlsSize,
18)]
19#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
20pub struct GroupContext {
21    pub version: ProtocolVersion,
22    pub cipher_suite: CiphersuiteId,
23    #[tls_codec(with = "crate::tlspl::bytes")]
24    group_id: GroupId,
25    pub epoch: u64,
26    pub tree_hash: TreeHash,
27    pub confirmed_transcript_hash: TranscriptHash,
28    pub extensions: Vec<Extension>,
29}
30
31impl GroupContext {
32    /// Allows for initialization with an arbitrary group id
33    pub fn with_group_id(group_id: GroupId) -> Self {
34        Self {
35            group_id,
36            ..Default::default()
37        }
38    }
39
40    // 8.1 -> The `group_id` field is constant
41    pub fn group_id(&self) -> &[u8] {
42        &self.group_id
43    }
44
45    pub fn external_senders(&self) -> &[ExternalSender] {
46        self.extensions
47            .iter()
48            .find_map(|ext| {
49                if let Extension::ExternalSenders(ext_senders) = ext {
50                    Some(ext_senders.as_slice())
51                } else {
52                    None
53                }
54            })
55            .unwrap_or_default()
56    }
57
58    pub fn required_capabilities(&self) -> Option<&RequiredCapabilities> {
59        self.extensions.iter().find_map(|ext| {
60            if let Extension::RequiredCapabilities(required_caps) = ext {
61                Some(required_caps)
62            } else {
63                None
64            }
65        })
66    }
67}
68
/// Selects one of the per-epoch secrets.
///
/// Each variant converts into a `KdfLabelKind` through the `From` impl in this
/// module. Variants use implicit `u8` discriminants, so their declaration order
/// defines their numeric values.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum EpochSecretExport {
    /// Maps to `KdfLabelKind::SenderData`.
    SenderDataSecret,
    /// Maps to `KdfLabelKind::Encryption`.
    EncryptionSecret,
    /// Maps to `KdfLabelKind::Exporter`.
    ExporterSecret,
    /// Maps to `KdfLabelKind::External`.
    ExternalSecret,
    /// Maps to `KdfLabelKind::Confirm`.
    ConfirmationKey,
    /// Maps to `KdfLabelKind::Membership`.
    MembershipKey,
    /// Maps to `KdfLabelKind::Resumption`.
    ResumptionPsk,
    /// Maps to `KdfLabelKind::Authentication`.
    EpochAuthenticator,
    /// Maps to `KdfLabelKind::AssociatedPartyEpochSecret`; only available with
    /// the `draft-kohbrok-mls-associated-parties` feature.
    #[cfg(feature = "draft-kohbrok-mls-associated-parties")]
    AssociatedPartiesSecret,
}
84
85impl From<EpochSecretExport> for KdfLabelKind {
86    fn from(value: EpochSecretExport) -> Self {
87        match value {
88            EpochSecretExport::SenderDataSecret => KdfLabelKind::SenderData,
89            EpochSecretExport::EncryptionSecret => KdfLabelKind::Encryption,
90            EpochSecretExport::ExporterSecret => KdfLabelKind::Exporter,
91            EpochSecretExport::ExternalSecret => KdfLabelKind::External,
92            EpochSecretExport::ConfirmationKey => KdfLabelKind::Confirm,
93            EpochSecretExport::MembershipKey => KdfLabelKind::Membership,
94            EpochSecretExport::ResumptionPsk => KdfLabelKind::Resumption,
95            EpochSecretExport::EpochAuthenticator => KdfLabelKind::Authentication,
96            #[cfg(feature = "draft-kohbrok-mls-associated-parties")]
97            EpochSecretExport::AssociatedPartiesSecret => KdfLabelKind::AssociatedPartyEpochSecret,
98        }
99    }
100}
101
102pub type TranscriptHash = SensitiveBytes;
103
104#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
105#[cfg_attr(feature = "serde", derive(serde::Serialize))]
106pub struct ConfirmedTranscriptHashInput<'a> {
107    pub wire_format: &'a WireFormat,
108    pub content: &'a FramedContent,
109    pub signature: &'a [u8],
110}
111
112#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
113#[cfg_attr(feature = "serde", derive(serde::Serialize))]
114pub struct InterimTranscriptHashInput<'a> {
115    pub confirmation_tag: &'a [u8],
116}
117
118impl<'a> From<&'a [u8]> for InterimTranscriptHashInput<'a> {
119    fn from(confirmation_tag: &'a [u8]) -> Self {
120        Self { confirmation_tag }
121    }
122}
123
124#[derive(
125    Debug,
126    Copy,
127    Clone,
128    PartialEq,
129    Eq,
130    tls_codec::TlsSerialize,
131    tls_codec::TlsDeserialize,
132    tls_codec::TlsSize,
133)]
134#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
135#[repr(u8)]
136pub enum PskType {
137    Reserved = 0x00,
138    External = 0x01,
139    Resumption = 0x02,
140    #[cfg(feature = "draft-ietf-mls-extensions")]
141    Application = 0x03,
142}
143
144#[derive(
145    Debug,
146    Copy,
147    Clone,
148    PartialEq,
149    Eq,
150    Hash,
151    tls_codec::TlsSerialize,
152    tls_codec::TlsDeserialize,
153    tls_codec::TlsSize,
154)]
155#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
156#[repr(u8)]
157pub enum ResumptionPskUsage {
158    Reserved = 0x00,
159    Application = 0x01,
160    ReInit = 0x02,
161    Branch = 0x03,
162}
163
164#[derive(
165    Debug,
166    Clone,
167    PartialEq,
168    Eq,
169    Hash,
170    tls_codec::TlsSerialize,
171    tls_codec::TlsDeserialize,
172    tls_codec::TlsSize,
173    zeroize::Zeroize,
174    zeroize::ZeroizeOnDrop,
175)]
176#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
177#[repr(u8)]
178pub enum PreSharedKeyIdPskType {
179    #[tls_codec(discriminant = "PskType::External")]
180    External(ExternalPsk),
181    #[tls_codec(discriminant = "PskType::Resumption")]
182    Resumption(ResumptionPsk),
183    #[cfg(feature = "draft-ietf-mls-extensions")]
184    #[tls_codec(discriminant = "PskType::Application")]
185    Application(ApplicationPsk),
186}
187
188#[derive(
189    Debug,
190    Clone,
191    PartialEq,
192    Eq,
193    Hash,
194    tls_codec::TlsSerialize,
195    tls_codec::TlsDeserialize,
196    tls_codec::TlsSize,
197    zeroize::Zeroize,
198    zeroize::ZeroizeOnDrop,
199)]
200#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
201pub struct ExternalPsk {
202    #[tls_codec(with = "crate::tlspl::bytes")]
203    pub psk_id: Vec<u8>,
204}
205
206#[derive(
207    Debug,
208    Clone,
209    PartialEq,
210    Eq,
211    Hash,
212    tls_codec::TlsSerialize,
213    tls_codec::TlsDeserialize,
214    tls_codec::TlsSize,
215    zeroize::Zeroize,
216    zeroize::ZeroizeOnDrop,
217)]
218#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
219pub struct ResumptionPsk {
220    #[zeroize(skip)]
221    pub usage: ResumptionPskUsage,
222    #[tls_codec(with = "crate::tlspl::bytes")]
223    pub psk_group_id: Vec<u8>,
224    pub psk_epoch: Epoch,
225}
226
/// Application PSK reference, scoped to a component
/// (`draft-ietf-mls-extensions` feature only).
#[cfg(feature = "draft-ietf-mls-extensions")]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(tls_codec::TlsSerialize, tls_codec::TlsDeserialize, tls_codec::TlsSize)]
#[derive(zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ApplicationPsk {
    /// Owning component; excluded from zeroization.
    #[zeroize(skip)]
    pub component_id: crate::drafts::mls_extensions::safe_application::ComponentId,
    /// PSK identifier bytes, (de)serialized via `crate::tlspl::bytes`.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub psk_id: Vec<u8>,
}
247
248#[derive(
249    Debug,
250    Clone,
251    PartialEq,
252    Eq,
253    Hash,
254    tls_codec::TlsSerialize,
255    tls_codec::TlsDeserialize,
256    tls_codec::TlsSize,
257    zeroize::Zeroize,
258    zeroize::ZeroizeOnDrop,
259)]
260#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
261pub struct PreSharedKeyId {
262    pub psktype: PreSharedKeyIdPskType,
263    pub psk_nonce: SensitiveBytes,
264}
265
266impl PreSharedKeyId {
267    pub fn with_default_nonce(&self) -> Self {
268        Self {
269            psktype: self.psktype.clone(),
270            psk_nonce: SensitiveBytes::default(),
271        }
272    }
273}
274
275#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
276#[cfg_attr(feature = "serde", derive(serde::Serialize))]
277pub struct PskLabel<'a> {
278    pub id: &'a PreSharedKeyId,
279    pub index: u16,
280    pub count: u16,
281}