mls_spec/drafts/mls_extensions/
safe_application.rs

1use std::collections::BTreeMap;
2
3use crate::{
4    SensitiveBytes, ToPrefixedLabel,
5    defs::{
6        ProtocolVersion,
7        labels::{PublicKeyEncryptionLabel, SignatureLabel},
8    },
9    key_schedule::PreSharedKeyId,
10};
11
/// Identifier of a Safe Application API component.
///
/// Used as the `component_id` field of every component-tagged structure in this module.
pub type ComponentId = u32;
13
14pub trait Component: crate::Parsable + crate::Serializable {
15    fn component_id() -> ComponentId;
16
17    fn psk(psk_id: Vec<u8>, psk_nonce: SensitiveBytes) -> PreSharedKeyId {
18        PreSharedKeyId {
19            psktype: crate::key_schedule::PreSharedKeyIdPskType::Application(
20                crate::key_schedule::ApplicationPsk {
21                    component_id: Self::component_id(),
22                    psk_id,
23                },
24            ),
25            psk_nonce,
26        }
27    }
28
29    fn to_component_data(&self) -> crate::MlsSpecResult<ComponentData> {
30        Ok(ComponentData {
31            component_id: Self::component_id(),
32            data: self.to_tls_bytes()?,
33        })
34    }
35}
36
/// Label structure binding a cryptographic operation to a specific component.
///
/// Field order below is the TLS-PL wire order. `label` is private, so it can only be
/// set through the constructors in this module.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComponentOperationLabel {
    /// Prefixed protocol label bytes, filled in by the constructors.
    #[tls_codec(with = "crate::tlspl::bytes")]
    label: Vec<u8>,
    /// Component on whose behalf the operation is performed.
    pub component_id: ComponentId,
    /// Caller-supplied context bytes.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub context: Vec<u8>,
}
54
55impl ComponentOperationLabel {
56    /// This follows the spec for the following construct: <https://www.ietf.org/archive/id/draft-ietf-mls-extensions-06.html#name-hybrid-public-key-encryptio>
57    pub fn new_with_context_for_hpke(component_id: ComponentId, context: Vec<u8>) -> Self {
58        Self {
59            label: PublicKeyEncryptionLabel::SafeApp
60                .to_prefixed_string(ProtocolVersion::default())
61                .into_bytes(),
62            component_id,
63            context,
64        }
65    }
66
67    /// This follows the spec for the following construct: <https://www.ietf.org/archive/id/draft-ietf-mls-extensions-06.html#name-hybrid-public-key-encryptio>
68    pub fn new_for_signature(
69        label: SignatureLabel,
70        component_id: ComponentId,
71        context: Vec<u8>,
72    ) -> (Self, SignatureLabel) {
73        let col = Self {
74            label: label
75                .to_prefixed_string(ProtocolVersion::default())
76                .into_bytes(),
77            component_id,
78            context,
79        };
80
81        (col, SignatureLabel::ComponentOperationLabel)
82    }
83}
84
/// Owned, component-tagged opaque payload: a `component_id` followed by its data bytes.
///
/// Field order below is the TLS-PL wire order.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComponentData {
    /// Component this payload belongs to.
    pub component_id: ComponentId,
    /// Component-specific encoded payload.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub data: Vec<u8>,
}
101
102impl ComponentData {
103    pub fn as_ref(&self) -> ComponentDataRef<'_> {
104        ComponentDataRef {
105            component_id: &self.component_id,
106            data: &self.data,
107        }
108    }
109}
110
/// Borrowed counterpart of [`ComponentData`], serialized with the same field layout.
///
/// Only serialization is derived; there is no deserializer for this borrowed form.
#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ComponentDataRef<'a> {
    /// Component this payload belongs to.
    pub component_id: &'a ComponentId,
    /// Component-specific encoded payload.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub data: &'a [u8],
}
118
/// Utility wrapper around a `BTreeMap` keyed by [`ComponentId`], so that component entries
/// are kept ordered by id and each id appears at most once.
///
/// With the `serde` feature, the value is (de)serialized through `Vec<ComponentData>` (see the
/// `from`/`into` attributes). Its serde representation therefore has the same shape as its
/// TLS-PL encoding: a list of `ComponentData` entries.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(from = "Vec<ComponentData>", into = "Vec<ComponentData>")
)]
pub struct ComponentDataMap(BTreeMap<ComponentId, Vec<u8>>);
130
131impl ComponentDataMap {
132    fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
133        self.0
134            .get(&C::component_id())
135            .map(|data| C::from_tls_bytes(data))
136            .transpose()
137    }
138
139    fn insert_or_update_component<C: Component>(
140        &mut self,
141        component: &C,
142    ) -> crate::MlsSpecResult<bool> {
143        // This is put before to make sure we don't error out on serialization before modifying the map
144        let component_data = component.to_tls_bytes()?;
145        match self.0.entry(C::component_id()) {
146            std::collections::btree_map::Entry::Vacant(vacant_entry) => {
147                vacant_entry.insert(component_data);
148                Ok(true)
149            }
150            std::collections::btree_map::Entry::Occupied(mut occupied_entry) => {
151                *(occupied_entry.get_mut()) = component_data;
152                Ok(false)
153            }
154        }
155    }
156
157    fn iter(&self) -> impl Iterator<Item = (&ComponentId, &[u8])> {
158        self.0.iter().map(|(cid, data)| (cid, data.as_slice()))
159    }
160}
161
162impl tls_codec::Size for ComponentDataMap {
163    fn tls_serialized_len(&self) -> usize {
164        crate::tlspl::tls_serialized_len_as_vlvec(
165            self.iter()
166                .map(|(component_id, data)| {
167                    ComponentDataRef { component_id, data }.tls_serialized_len()
168                })
169                .sum(),
170        )
171    }
172}
173
174impl tls_codec::Deserialize for ComponentDataMap {
175    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
176    where
177        Self: Sized,
178    {
179        let tlspl_value: Vec<ComponentData> = <_>::tls_deserialize(bytes)?;
180
181        Ok(Self(BTreeMap::from_iter(
182            tlspl_value
183                .into_iter()
184                .map(|cdata| (cdata.component_id, cdata.data)),
185        )))
186    }
187}
188
189impl tls_codec::Serialize for ComponentDataMap {
190    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
191        // TODO: Improve this by not allocating a vec of refs
192        self.iter()
193            .map(|(component_id, data)| ComponentDataRef { component_id, data })
194            .collect::<Vec<ComponentDataRef>>()
195            .tls_serialize(writer)
196    }
197}
198
199impl std::ops::Deref for ComponentDataMap {
200    type Target = BTreeMap<ComponentId, Vec<u8>>;
201    fn deref(&self) -> &Self::Target {
202        &self.0
203    }
204}
205
206impl std::ops::DerefMut for ComponentDataMap {
207    fn deref_mut(&mut self) -> &mut Self::Target {
208        &mut self.0
209    }
210}
211
212impl From<Vec<ComponentData>> for ComponentDataMap {
213    fn from(value: Vec<ComponentData>) -> Self {
214        Self(BTreeMap::from_iter(
215            value
216                .into_iter()
217                .map(|component| (component.component_id, component.data)),
218        ))
219    }
220}
221
222#[allow(clippy::from_over_into)]
223impl Into<Vec<ComponentData>> for ComponentDataMap {
224    fn into(self) -> Vec<ComponentData> {
225        self.0
226            .into_iter()
227            .map(|(component_id, data)| ComponentData { component_id, data })
228            .collect()
229    }
230}
231
/// Payload of the `ApplicationData` extension: component data entries keyed by `component_id`.
///
/// It is backed by a `BTreeMap` (through [`ComponentDataMap`]), so entries are kept ordered by
/// `component_id` and at most one entry per id is stored.
///
/// The conversion from/to a `Vec<ComponentData>` happens at serialization/deserialization time.
#[derive(
    Debug,
    Default,
    Clone,
    PartialEq,
    Eq,
    Hash,
    tls_codec::TlsSize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSerialize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ApplicationDataDictionary {
    /// Component entries, ordered and unique by `component_id`.
    pub component_data: ComponentDataMap,
}
251
252impl ApplicationDataDictionary {
253    pub fn iter_components(&self) -> impl Iterator<Item = ComponentDataRef<'_>> {
254        self.component_data
255            .iter()
256            .map(|(component_id, data)| ComponentDataRef { component_id, data })
257    }
258
259    pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
260        self.component_data.extract_component::<C>()
261    }
262
263    /// Returns `true` if newly inserted
264    pub fn insert_or_update_component<C: Component>(
265        &mut self,
266        component: &C,
267    ) -> crate::MlsSpecResult<bool> {
268        self.component_data.insert_or_update_component(component)
269    }
270
271    /// Applies an ApplicationDataUpdate proposal
272    ///
273    /// Returns `false` in only one case: when an `op` is set to `remove` tries to
274    /// remove a non-existing component, which is a soft-error in itself
275    pub fn apply_update(&mut self, update: AppDataUpdate) -> bool {
276        match update.op {
277            ApplicationDataUpdateOperation::Update { update: data } => {
278                *self.component_data.entry(update.component_id).or_default() = data;
279                true
280            }
281            ApplicationDataUpdateOperation::Remove => {
282                self.component_data.remove(&update.component_id).is_some()
283            }
284        }
285    }
286}
287
288impl From<ApplicationDataDictionary> for crate::group::extensions::Extension {
289    fn from(val: ApplicationDataDictionary) -> Self {
290        crate::group::extensions::Extension::ApplicationData(val)
291    }
292}
293
/// Wire discriminant (one byte) for [`ApplicationDataUpdateOperation`].
///
/// `Invalid` (0x00) is not used as the discriminant of any operation variant in this module.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[repr(u8)]
#[cfg_attr(
    feature = "serde",
    derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
)]
pub enum ApplicationDataUpdateOperationType {
    Invalid = 0x00,
    Update = 0x01,
    Remove = 0x02,
}
313
/// Operation carried by an [`AppDataUpdate`].
///
/// `Update` replaces the component's data with `update`, and `Remove` deletes the component's
/// entry (see `ApplicationDataDictionary::apply_update`). The TLS discriminants come from
/// [`ApplicationDataUpdateOperationType`].
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ApplicationDataUpdateOperation {
    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Update")]
    Update {
        /// New encoded data for the component.
        #[tls_codec(with = "crate::tlspl::bytes")]
        update: Vec<u8>,
    },
    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Remove")]
    Remove,
}
334
/// An update to one component's entry in an [`ApplicationDataDictionary`].
///
/// It is applied with `ApplicationDataDictionary::apply_update`, which documents itself as
/// applying an ApplicationDataUpdate proposal.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct AppDataUpdate {
    /// Component whose entry is updated or removed.
    pub component_id: ComponentId,
    /// The operation to perform on that entry.
    pub op: ApplicationDataUpdateOperation,
}
349
350impl AppDataUpdate {
351    /// Allows to extract a concrete `Component` from an update operation
352    ///
353    /// Returns Ok(None) if the update is a `Remove` operation
354    /// Otherwise returns Ok(Some(C)) unless an error occurs
355    pub fn extract_component_update<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
356        let type_component_id = C::component_id();
357        if type_component_id != self.component_id {
358            return Err(crate::MlsSpecError::SafeAppComponentIdMismatch {
359                expected: type_component_id,
360                actual: self.component_id,
361            });
362        }
363
364        let ApplicationDataUpdateOperation::Update { update } = &self.op else {
365            return Ok(None);
366        };
367
368        Ok(Some(C::from_tls_bytes(update)?))
369    }
370}
371
/// Owned, component-tagged opaque payload (`component_id` followed by data bytes).
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ApplicationData {
    /// Component this payload belongs to.
    pub component_id: ComponentId,
    /// Component-specific encoded payload.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub data: Vec<u8>,
}

/// `AppEphemeral` reuses the [`ApplicationData`] layout (`component_id` + `data`).
pub type AppEphemeral = ApplicationData;
389
/// Borrowed Safe AAD item: a component id plus that component's AAD bytes.
///
/// It serializes exactly like the wrapped [`ComponentDataRef`]; with serde, the wrapper is
/// `transparent`.
#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct SafeAadItemRef<'a>(ComponentDataRef<'a>);
394
395impl<'a> SafeAadItemRef<'a> {
396    pub fn component_id(&self) -> &ComponentId {
397        self.0.component_id
398    }
399
400    pub fn aad_item_data(&self) -> &[u8] {
401        self.0.data
402    }
403
404    pub fn from_item_data<C: Component>(
405        component_id: &'a ComponentId,
406        aad_item_data: &'a [u8],
407    ) -> Option<Self> {
408        (&C::component_id() == component_id).then_some(SafeAadItemRef(ComponentDataRef {
409            component_id,
410            data: aad_item_data,
411        }))
412    }
413}
414
/// Owned Safe AAD item: a component id plus that component's AAD bytes.
///
/// It serializes exactly like the wrapped [`ComponentData`]; with serde, the wrapper is
/// `transparent`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct SafeAadItem(ComponentData);
428
429impl SafeAadItem {
430    pub fn as_ref(&self) -> SafeAadItemRef<'_> {
431        SafeAadItemRef(self.0.as_ref())
432    }
433}
434
/// Borrowed list of Safe AAD items.
///
/// Ordering and uniqueness are not enforced on construction; use
/// `SafeAadRef::is_ordered_and_unique` to check them before relying on the list.
#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct SafeAadRef<'a> {
    /// Items in the order they will be serialized.
    pub aad_items: &'a [&'a SafeAadItemRef<'a>],
}
440
441impl SafeAadRef<'_> {
442    pub fn is_ordered_and_unique(&self) -> bool {
443        let mut iter = self.aad_items.iter().peekable();
444
445        while let Some(item) = iter.next() {
446            let Some(next) = iter.peek() else {
447                continue;
448            };
449
450            if item.component_id() >= next.component_id() {
451                return false;
452            }
453        }
454
455        true
456    }
457}
458
459impl<'a> From<&'a [&'a SafeAadItemRef<'a>]> for SafeAadRef<'a> {
460    fn from(aad_items: &'a [&'a SafeAadItemRef<'a>]) -> Self {
461        Self { aad_items }
462    }
463}
464
/// Owned Safe AAD: component AAD items stored in a [`ComponentDataMap`].
///
/// Because the map is keyed by `component_id`, items are always ordered by id and unique.
#[derive(
    Debug,
    Default,
    Clone,
    PartialEq,
    Eq,
    Hash,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SafeAad {
    /// AAD items keyed by component id.
    aad_items: ComponentDataMap,
}
480
481impl SafeAad {
482    pub fn iter_components(&self) -> impl Iterator<Item = SafeAadItemRef<'_>> {
483        self.aad_items
484            .iter()
485            .map(|(component_id, data)| SafeAadItemRef(ComponentDataRef { component_id, data }))
486    }
487
488    pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
489        self.aad_items.extract_component::<C>()
490    }
491
492    /// Returns `true` if newly inserted
493    pub fn insert_or_update_component<C: Component>(
494        &mut self,
495        component: &C,
496    ) -> crate::MlsSpecResult<bool> {
497        self.aad_items.insert_or_update_component(component)
498    }
499}
500
/// A list of [`crate::defs::WireFormat`] values.
///
/// NOTE(review): presumably the wire formats a component/group supports or accepts; confirm
/// against the draft section that defines this structure.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WireFormats {
    pub wire_formats: Vec<crate::defs::WireFormat>,
}
515
/// A list of component ids; used as the payload of [`AppComponents`] and [`SafeAadComponent`].
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComponentsList {
    pub component_ids: Vec<ComponentId>,
}
529
/// Component whose payload is a [`ComponentsList`], registered under `super::APP_COMPONENTS_ID`.
///
/// NOTE(review): presumably lists the application components in use by the group; confirm
/// against the draft.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
pub struct AppComponents(pub ComponentsList);

impl Component for AppComponents {
    fn component_id() -> ComponentId {
        super::APP_COMPONENTS_ID
    }
}
546
/// Component whose payload is a [`ComponentsList`], registered under `super::SAFE_AAD_ID`.
///
/// NOTE(review): presumably lists the components that may contribute Safe AAD items;
/// confirm against the draft.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
pub struct SafeAadComponent(pub ComponentsList);

impl Component for SafeAadComponent {
    fn component_id() -> ComponentId {
        super::SAFE_AAD_ID
    }
}
563
// Unit tests for the Safe Application types in this module.
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{ApplicationDataDictionary, Component, SafeAad, SafeAadItemRef, SafeAadRef};
    use crate::{
        drafts::mls_extensions::last_resort_keypackage::LastResortKeyPackage,
        generate_roundtrip_test,
    };

    // Entries are listed out of key order on purpose; the backing `BTreeMap` stores them
    // sorted by component id. NOTE(review): the exact checks performed are defined by
    // `generate_roundtrip_test!`.
    generate_roundtrip_test!(can_roundtrip_appdatadict, {
        ApplicationDataDictionary {
            component_data: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    // Same as above, for `SafeAad`, which shares the `ComponentDataMap` representation.
    generate_roundtrip_test!(can_roundtrip_safeaad, {
        SafeAad {
            aad_items: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    // Inserts a component into a `SafeAad`, then builds a single-item `SafeAadRef` by hand
    // and checks that the ordering/uniqueness predicate accepts it.
    #[test]
    fn can_build_safe_aad() {
        let mut safe_aad = SafeAad::default();
        safe_aad
            .insert_or_update_component(&LastResortKeyPackage)
            .unwrap();

        let cid = LastResortKeyPackage::component_id();
        // Empty item data is used here; NOTE(review): presumably matches the (empty)
        // encoding of `LastResortKeyPackage`.
        let aad_item_ref =
            SafeAadItemRef::from_item_data::<LastResortKeyPackage>(&cid, &[]).unwrap();

        let items = &[&aad_item_ref];
        let safe_ref = SafeAadRef::from(items.as_slice());
        assert!(safe_ref.is_ordered_and_unique());
    }
}