// mls_spec/drafts/mls_extensions/safe_application.rs

use std::collections::BTreeMap;

use crate::{
    SensitiveBytes, ToPrefixedLabel,
    defs::{
        ProtocolVersion,
        labels::{PublicKeyEncryptionLabel, SignatureLabel},
    },
    key_schedule::PreSharedKeyId,
};

/// Numeric identifier of an application component.
///
/// It is the key used by every component-indexed structure in this module.
pub type ComponentId = u32;

14pub trait Component: crate::Parsable + crate::Serializable {
15    fn component_id() -> ComponentId;
16
17    fn psk(psk_id: Vec<u8>, psk_nonce: SensitiveBytes) -> PreSharedKeyId {
18        PreSharedKeyId {
19            psktype: crate::key_schedule::PreSharedKeyIdPskType::Application(
20                crate::key_schedule::ApplicationPsk {
21                    component_id: Self::component_id(),
22                    psk_id,
23                },
24            ),
25            psk_nonce,
26        }
27    }
28
29    fn to_component_data(&self) -> crate::MlsSpecResult<ComponentData> {
30        Ok(ComponentData {
31            component_id: Self::component_id(),
32            data: self.to_tls_bytes()?,
33        })
34    }
35}
36
37#[derive(
38    Debug,
39    Clone,
40    PartialEq,
41    Eq,
42    tls_codec::TlsSerialize,
43    tls_codec::TlsDeserialize,
44    tls_codec::TlsSize,
45)]
46#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
47pub struct ComponentOperationLabel {
48    #[tls_codec(with = "crate::tlspl::bytes")]
49    label: Vec<u8>,
50    pub component_id: ComponentId,
51    #[tls_codec(with = "crate::tlspl::bytes")]
52    pub context: Vec<u8>,
53}
54
55impl ComponentOperationLabel {
56    /// This follows the spec for the following construct: <https://www.ietf.org/archive/id/draft-ietf-mls-extensions-06.html#name-hybrid-public-key-encryptio>
57    pub fn new_with_context_for_hpke(component_id: ComponentId, context: Vec<u8>) -> Self {
58        Self {
59            label: PublicKeyEncryptionLabel::SafeApp
60                .to_prefixed_string(ProtocolVersion::default())
61                .into_bytes(),
62            component_id,
63            context,
64        }
65    }
66
67    /// This follows the spec for the following construct: <https://www.ietf.org/archive/id/draft-ietf-mls-extensions-06.html#name-hybrid-public-key-encryptio>
68    pub fn new_for_signature(
69        label: SignatureLabel,
70        component_id: ComponentId,
71        context: Vec<u8>,
72    ) -> (Self, SignatureLabel) {
73        let col = Self {
74            label: label
75                .to_prefixed_string(ProtocolVersion::default())
76                .into_bytes(),
77            component_id,
78            context,
79        };
80
81        (col, SignatureLabel::ComponentOperationLabel)
82    }
83}
84
85#[derive(
86    Debug,
87    Clone,
88    PartialEq,
89    Eq,
90    tls_codec::TlsSerialize,
91    tls_codec::TlsDeserialize,
92    tls_codec::TlsSize,
93)]
94#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
95pub struct ComponentData {
96    pub component_id: ComponentId,
97    #[tls_codec(with = "crate::tlspl::bytes")]
98    pub data: Vec<u8>,
99}
100
101impl ComponentData {
102    pub fn as_ref(&self) -> ComponentDataRef {
103        ComponentDataRef {
104            component_id: &self.component_id,
105            data: &self.data,
106        }
107    }
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
111#[cfg_attr(feature = "serde", derive(serde::Serialize))]
112pub struct ComponentDataRef<'a> {
113    pub component_id: &'a ComponentId,
114    #[tls_codec(with = "crate::tlspl::bytes")]
115    pub data: &'a [u8],
116}
117
118/// Utilitary struct that contains a `BTreeMap` in order to preserve ordering and unicity
119///
120/// Also takes extra care to make sure that the `serde` representation when serialized
121/// is equivalent to the TLS-PL version of it
122#[derive(Debug, Default, Clone, PartialEq, Eq)]
123#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
124#[cfg_attr(
125    feature = "serde",
126    serde(from = "Vec<ComponentData>", into = "Vec<ComponentData>")
127)]
128pub struct ComponentDataMap(BTreeMap<ComponentId, Vec<u8>>);
129
130impl ComponentDataMap {
131    fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
132        self.0
133            .get(&C::component_id())
134            .map(|data| C::from_tls_bytes(data))
135            .transpose()
136    }
137
138    fn insert_or_update_component<C: Component>(
139        &mut self,
140        component: &C,
141    ) -> crate::MlsSpecResult<bool> {
142        // This is put before to make sure we don't error out on serialization before modifying the map
143        let component_data = component.to_tls_bytes()?;
144        match self.0.entry(C::component_id()) {
145            std::collections::btree_map::Entry::Vacant(vacant_entry) => {
146                vacant_entry.insert(component_data);
147                Ok(true)
148            }
149            std::collections::btree_map::Entry::Occupied(mut occupied_entry) => {
150                *(occupied_entry.get_mut()) = component_data;
151                Ok(false)
152            }
153        }
154    }
155
156    fn iter(&self) -> impl Iterator<Item = (&ComponentId, &[u8])> {
157        self.0.iter().map(|(cid, data)| (cid, data.as_slice()))
158    }
159}
160
161impl tls_codec::Size for ComponentDataMap {
162    fn tls_serialized_len(&self) -> usize {
163        crate::tlspl::tls_serialized_len_as_vlvec(
164            self.iter()
165                .map(|(component_id, data)| {
166                    ComponentDataRef { component_id, data }.tls_serialized_len()
167                })
168                .sum(),
169        )
170    }
171}
172
173impl tls_codec::Deserialize for ComponentDataMap {
174    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
175    where
176        Self: Sized,
177    {
178        let tlspl_value: Vec<ComponentData> = <_>::tls_deserialize(bytes)?;
179
180        Ok(Self(BTreeMap::from_iter(
181            tlspl_value
182                .into_iter()
183                .map(|cdata| (cdata.component_id, cdata.data)),
184        )))
185    }
186}
187
188impl tls_codec::Serialize for ComponentDataMap {
189    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
190        // TODO: Improve this by not allocating a vec of refs
191        self.iter()
192            .map(|(component_id, data)| ComponentDataRef { component_id, data })
193            .collect::<Vec<ComponentDataRef>>()
194            .tls_serialize(writer)
195    }
196}
197
198impl std::ops::Deref for ComponentDataMap {
199    type Target = BTreeMap<ComponentId, Vec<u8>>;
200    fn deref(&self) -> &Self::Target {
201        &self.0
202    }
203}
204
205impl std::ops::DerefMut for ComponentDataMap {
206    fn deref_mut(&mut self) -> &mut Self::Target {
207        &mut self.0
208    }
209}
210
211impl From<Vec<ComponentData>> for ComponentDataMap {
212    fn from(value: Vec<ComponentData>) -> Self {
213        Self(BTreeMap::from_iter(
214            value
215                .into_iter()
216                .map(|component| (component.component_id, component.data)),
217        ))
218    }
219}
220
221#[allow(clippy::from_over_into)]
222impl Into<Vec<ComponentData>> for ComponentDataMap {
223    fn into(self) -> Vec<ComponentData> {
224        self.0
225            .into_iter()
226            .map(|(component_id, data)| ComponentData { component_id, data })
227            .collect()
228    }
229}
230
231/// Please note that this ApplicationDataDictionary is backed by a `BTreeMap` to
232/// take care of ordering and deduplication automatically.
233///
234/// The conversion from/to a `Vec<ComponentData>` is done at serialization/deserialization time
235#[derive(
236    Debug,
237    Default,
238    Clone,
239    PartialEq,
240    Eq,
241    tls_codec::TlsSize,
242    tls_codec::TlsDeserialize,
243    tls_codec::TlsSerialize,
244)]
245#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
246pub struct ApplicationDataDictionary {
247    pub component_data: ComponentDataMap,
248}
249
250impl ApplicationDataDictionary {
251    pub fn iter_components(&self) -> impl Iterator<Item = ComponentDataRef> {
252        self.component_data
253            .iter()
254            .map(|(component_id, data)| ComponentDataRef { component_id, data })
255    }
256
257    pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
258        self.component_data.extract_component::<C>()
259    }
260
261    /// Returns `true` if newly inserted
262    pub fn insert_or_update_component<C: Component>(
263        &mut self,
264        component: &C,
265    ) -> crate::MlsSpecResult<bool> {
266        self.component_data.insert_or_update_component(component)
267    }
268
269    /// Applies an ApplicationDataUpdate proposal
270    ///
271    /// Returns `false` in only one case: when an `op` is set to `remove` tries to
272    /// remove a non-existing component, which is a soft-error in itself
273    pub fn apply_update(&mut self, update: AppDataUpdate) -> bool {
274        match update.op {
275            ApplicationDataUpdateOperation::Update { update: data } => {
276                *self.component_data.entry(update.component_id).or_default() = data;
277                true
278            }
279            ApplicationDataUpdateOperation::Remove => {
280                self.component_data.remove(&update.component_id).is_some()
281            }
282        }
283    }
284}
285
286impl From<ApplicationDataDictionary> for crate::group::extensions::Extension {
287    fn from(val: ApplicationDataDictionary) -> Self {
288        crate::group::extensions::Extension::ApplicationData(val)
289    }
290}
291
292#[derive(
293    Debug,
294    Clone,
295    PartialEq,
296    Eq,
297    tls_codec::TlsSerialize,
298    tls_codec::TlsDeserialize,
299    tls_codec::TlsSize,
300)]
301#[repr(u8)]
302#[cfg_attr(
303    feature = "serde",
304    derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
305)]
306pub enum ApplicationDataUpdateOperationType {
307    Invalid = 0x00,
308    Update = 0x01,
309    Remove = 0x02,
310}
311
312#[derive(
313    Debug,
314    Clone,
315    PartialEq,
316    Eq,
317    tls_codec::TlsSerialize,
318    tls_codec::TlsDeserialize,
319    tls_codec::TlsSize,
320)]
321#[repr(u8)]
322#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
323pub enum ApplicationDataUpdateOperation {
324    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Update")]
325    Update {
326        #[tls_codec(with = "crate::tlspl::bytes")]
327        update: Vec<u8>,
328    },
329    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Remove")]
330    Remove,
331}
332
333#[derive(
334    Debug,
335    Clone,
336    PartialEq,
337    Eq,
338    tls_codec::TlsSerialize,
339    tls_codec::TlsDeserialize,
340    tls_codec::TlsSize,
341)]
342#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
343pub struct AppDataUpdate {
344    pub component_id: ComponentId,
345    pub op: ApplicationDataUpdateOperation,
346}
347
348impl AppDataUpdate {
349    /// Allows to extract a concrete `Component` from an update operation
350    ///
351    /// Returns Ok(None) if the update is a `Remove` operation
352    /// Otherwise returns Ok(Some(C)) unless an error occurs
353    pub fn extract_component_update<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
354        let type_component_id = C::component_id();
355        if type_component_id != self.component_id {
356            return Err(crate::MlsSpecError::SafeAppComponentIdMismatch {
357                expected: type_component_id,
358                actual: self.component_id,
359            });
360        }
361
362        let ApplicationDataUpdateOperation::Update { update } = &self.op else {
363            return Ok(None);
364        };
365
366        Ok(Some(C::from_tls_bytes(update)?))
367    }
368}
369
370#[derive(
371    Debug,
372    Clone,
373    PartialEq,
374    Eq,
375    tls_codec::TlsSerialize,
376    tls_codec::TlsDeserialize,
377    tls_codec::TlsSize,
378)]
379#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
380pub struct ApplicationData {
381    pub component_id: ComponentId,
382    #[tls_codec(with = "crate::tlspl::bytes")]
383    pub data: Vec<u8>,
384}
385
386pub type AppEphemeral = ApplicationData;
387
388#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
389#[cfg_attr(feature = "serde", derive(serde::Serialize))]
390#[cfg_attr(feature = "serde", serde(transparent))]
391pub struct SafeAadItemRef<'a>(ComponentDataRef<'a>);
392
393impl<'a> SafeAadItemRef<'a> {
394    pub fn component_id(&self) -> &ComponentId {
395        self.0.component_id
396    }
397
398    pub fn aad_item_data(&self) -> &[u8] {
399        self.0.data
400    }
401
402    pub fn from_item_data<C: Component>(
403        component_id: &'a ComponentId,
404        aad_item_data: &'a [u8],
405    ) -> Option<Self> {
406        (&C::component_id() == component_id).then(|| {
407            SafeAadItemRef(ComponentDataRef {
408                component_id,
409                data: aad_item_data,
410            })
411        })
412    }
413}
414
415#[derive(
416    Debug,
417    Clone,
418    PartialEq,
419    Eq,
420    tls_codec::TlsSerialize,
421    tls_codec::TlsDeserialize,
422    tls_codec::TlsSize,
423)]
424#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
425#[cfg_attr(feature = "serde", serde(transparent))]
426pub struct SafeAadItem(ComponentData);
427
428impl SafeAadItem {
429    pub fn as_ref(&self) -> SafeAadItemRef {
430        SafeAadItemRef(self.0.as_ref())
431    }
432}
433
434#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
435#[cfg_attr(feature = "serde", derive(serde::Serialize))]
436pub struct SafeAadRef<'a> {
437    pub aad_items: &'a [&'a SafeAadItemRef<'a>],
438}
439
440impl SafeAadRef<'_> {
441    pub fn is_ordered_and_unique(&self) -> bool {
442        let mut iter = self.aad_items.iter().peekable();
443
444        while let Some(item) = iter.next() {
445            let Some(next) = iter.peek() else {
446                continue;
447            };
448
449            if item.component_id() >= next.component_id() {
450                return false;
451            }
452        }
453
454        true
455    }
456}
457
458impl<'a> From<&'a [&'a SafeAadItemRef<'a>]> for SafeAadRef<'a> {
459    fn from(aad_items: &'a [&'a SafeAadItemRef<'a>]) -> Self {
460        Self { aad_items }
461    }
462}
463
464#[derive(
465    Debug,
466    Default,
467    Clone,
468    PartialEq,
469    Eq,
470    tls_codec::TlsSerialize,
471    tls_codec::TlsDeserialize,
472    tls_codec::TlsSize,
473)]
474#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
475pub struct SafeAad {
476    aad_items: ComponentDataMap,
477}
478
479impl SafeAad {
480    pub fn iter_components(&self) -> impl Iterator<Item = SafeAadItemRef> {
481        self.aad_items
482            .iter()
483            .map(|(component_id, data)| SafeAadItemRef(ComponentDataRef { component_id, data }))
484    }
485
486    pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
487        self.aad_items.extract_component::<C>()
488    }
489
490    /// Returns `true` if newly inserted
491    pub fn insert_or_update_component<C: Component>(
492        &mut self,
493        component: &C,
494    ) -> crate::MlsSpecResult<bool> {
495        self.aad_items.insert_or_update_component(component)
496    }
497}
498
499#[derive(
500    Debug,
501    Clone,
502    PartialEq,
503    Eq,
504    tls_codec::TlsSerialize,
505    tls_codec::TlsDeserialize,
506    tls_codec::TlsSize,
507)]
508#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
509pub struct WireFormats {
510    pub wire_formats: Vec<crate::defs::WireFormat>,
511}
512
513#[derive(
514    Debug,
515    Clone,
516    PartialEq,
517    Eq,
518    tls_codec::TlsSerialize,
519    tls_codec::TlsDeserialize,
520    tls_codec::TlsSize,
521)]
522#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
523pub struct ComponentsList {
524    pub component_ids: Vec<ComponentId>,
525}
526
527#[derive(
528    Debug,
529    Clone,
530    PartialEq,
531    Eq,
532    tls_codec::TlsSerialize,
533    tls_codec::TlsDeserialize,
534    tls_codec::TlsSize,
535)]
536pub struct AppComponents(pub ComponentsList);
537
538impl Component for AppComponents {
539    fn component_id() -> ComponentId {
540        super::APP_COMPONENTS_ID
541    }
542}
543
544#[derive(
545    Debug,
546    Clone,
547    PartialEq,
548    Eq,
549    tls_codec::TlsSerialize,
550    tls_codec::TlsDeserialize,
551    tls_codec::TlsSize,
552)]
553pub struct SafeAadComponent(pub ComponentsList);
554
555impl Component for SafeAadComponent {
556    fn component_id() -> ComponentId {
557        super::SAFE_AAD_ID
558    }
559}
560
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{ApplicationDataDictionary, Component, SafeAad, SafeAadItemRef, SafeAadRef};
    use crate::{
        drafts::mls_extensions::last_resort_keypackage::LastResortKeyPackage,
        generate_roundtrip_test,
    };

    // The entries are deliberately listed out of order: the backing `BTreeMap`
    // is expected to sort them by component id.
    generate_roundtrip_test!(can_roundtrip_appdatadict, {
        ApplicationDataDictionary {
            component_data: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    // Same out-of-order fixture, this time for `SafeAad`.
    generate_roundtrip_test!(can_roundtrip_safeaad, {
        SafeAad {
            aad_items: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    /// A single-item AAD list built from a typed component is ordered and unique.
    #[test]
    fn can_build_safe_aad() {
        let mut safe_aad = SafeAad::default();
        safe_aad
            .insert_or_update_component(&LastResortKeyPackage)
            .unwrap();

        let cid = LastResortKeyPackage::component_id();
        let aad_item_ref =
            SafeAadItemRef::from_item_data::<LastResortKeyPackage>(&cid, &[]).unwrap();

        let items = &[&aad_item_ref];
        let safe_ref = SafeAadRef::from(items.as_slice());
        assert!(safe_ref.is_ordered_and_unique());
    }
}