Skip to main content

mls_spec/drafts/mls_extensions/
safe_application.rs

1use std::collections::BTreeMap;
2
3use crate::{SensitiveBytes, key_schedule::PreSharedKeyId};
4
/// Numeric identifier of an application component.
pub type ComponentId = u32;

/// Component IDs reserved for GREASE.
///
/// There are fifteen of them, following the pattern `0x0000_XAXA` for `X` in `0x0..=0xE`
/// (`0x0A0A`, `0x1A1A`, ..., `0xEAEA`). Entry `i` equals `0x0A0A + 0x1010 * i`, which is how
/// the table is generated below.
pub const COMPONENT_ID_GREASE_VALUES: [ComponentId; 15] = {
    let mut values: [ComponentId; 15] = [0; 15];
    let mut i = 0;
    while i < 15 {
        values[i] = 0x0A0A + 0x1010 * i as ComponentId;
        i += 1;
    }
    values
};
24
25pub trait Component: crate::Parsable + crate::Serializable {
26    fn component_id() -> ComponentId;
27
28    fn psk(psk_id: Vec<u8>, psk_nonce: SensitiveBytes) -> PreSharedKeyId {
29        PreSharedKeyId {
30            psktype: crate::key_schedule::PreSharedKeyIdPskType::Application(
31                crate::key_schedule::ApplicationPsk {
32                    component_id: Self::component_id(),
33                    psk_id,
34                },
35            ),
36            psk_nonce,
37        }
38    }
39
40    fn to_component_data(&self) -> crate::MlsSpecResult<ComponentData> {
41        Ok(ComponentData {
42            component_id: Self::component_id(),
43            data: self.to_tls_bytes()?,
44        })
45    }
46}
47
/// Base label placed first in a [`ComponentOperationLabel`].
///
/// The TLS encoding (the manual `tls_codec` impls for this type) writes the label as a
/// TLS-PL string holding the variant's `strum` name (`"Application"`), not as a byte. The
/// `#[repr(u8)]` discriminant therefore plays no part in that encoding.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Eq,
    strum::IntoStaticStr,
    strum::EnumString,
    strum::Display,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum ComponentOperationBaseLabel {
    #[default]
    Application = 0x00,
}
65
66impl tls_codec::Size for ComponentOperationBaseLabel {
67    fn tls_serialized_len(&self) -> usize {
68        crate::tlspl::string::tls_serialized_len(self.into())
69    }
70}
71
72impl tls_codec::Serialize for ComponentOperationBaseLabel {
73    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
74        crate::tlspl::string::tls_serialize(self.into(), writer)
75    }
76}
77
78impl tls_codec::Deserialize for ComponentOperationBaseLabel {
79    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
80    where
81        Self: Sized,
82    {
83        <Self as std::str::FromStr>::from_str(&crate::tlspl::string::tls_deserialize(bytes)?)
84            .map_err(|_| {
85                tls_codec::Error::DecodingError(
86                    "Unknown Value in ComponentOperationBaseLabel".into(),
87                )
88            })
89    }
90}
91
/// A component-scoped operation label, TLS-PL encoded in field order:
/// the fixed [`ComponentOperationBaseLabel`], the component's [`ComponentId`], then a
/// caller-supplied, length-prefixed `label`.
///
/// NOTE(review): presumably used as the label input for component-scoped operations
/// (signing / encryption / export) from the MLS extensions draft — confirm at call sites.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComponentOperationLabel {
    pub base_label: ComponentOperationBaseLabel,
    pub component_id: ComponentId,
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub label: Vec<u8>,
}
108
/// One component entry: its [`ComponentId`] and an opaque, length-prefixed byte payload.
///
/// The payload is whatever bytes the producer stored. For example,
/// [`Component::to_component_data`] fills it with the component's TLS-PL encoding.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComponentData {
    pub component_id: ComponentId,
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub data: Vec<u8>,
}
125
126impl ComponentData {
127    pub fn as_ref(&self) -> ComponentDataRef<'_> {
128        ComponentDataRef {
129            component_id: &self.component_id,
130            data: &self.data,
131        }
132    }
133}
134
/// Borrowed counterpart of [`ComponentData`].
///
/// It has the same field order and the same `tlspl::bytes` encoding for `data`, so it
/// serializes to the same bytes as the owned type. Only serialization is derived.
#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ComponentDataRef<'a> {
    pub component_id: &'a ComponentId,
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub data: &'a [u8],
}
142
/// Utility wrapper around a `BTreeMap<ComponentId, Vec<u8>>` that keeps component entries
/// ordered by `component_id` and unique.
///
/// On the TLS side it is encoded as a variable-length vector of `ComponentData`. The `serde`
/// representation goes through the same `Vec<ComponentData>` shape via `serde(from/into)`,
/// so both formats agree.
///
/// NOTE(review): the `serde` path (`From<Vec<ComponentData>>`) collects into the `BTreeMap`.
/// It therefore accepts unsorted input and collapses duplicate IDs instead of rejecting them.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(from = "Vec<ComponentData>", into = "Vec<ComponentData>")
)]
pub struct ComponentDataMap(BTreeMap<ComponentId, Vec<u8>>);
154
155impl ComponentDataMap {
156    fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
157        self.0
158            .get(&C::component_id())
159            .map(|data| C::from_tls_bytes(data))
160            .transpose()
161    }
162
163    fn insert_or_update_component<C: Component>(
164        &mut self,
165        component: &C,
166    ) -> crate::MlsSpecResult<bool> {
167        // This is put before to make sure we don't error out on serialization before modifying the map
168        let component_data = component.to_tls_bytes()?;
169        match self.0.entry(C::component_id()) {
170            std::collections::btree_map::Entry::Vacant(vacant_entry) => {
171                vacant_entry.insert(component_data);
172                Ok(true)
173            }
174            std::collections::btree_map::Entry::Occupied(mut occupied_entry) => {
175                *(occupied_entry.get_mut()) = component_data;
176                Ok(false)
177            }
178        }
179    }
180
181    fn iter(&self) -> impl Iterator<Item = (&ComponentId, &[u8])> {
182        self.0.iter().map(|(cid, data)| (cid, data.as_slice()))
183    }
184}
185
186impl tls_codec::Size for ComponentDataMap {
187    fn tls_serialized_len(&self) -> usize {
188        crate::tlspl::tls_serialized_len_as_vlvec(
189            self.iter()
190                .map(|(component_id, data)| {
191                    ComponentDataRef { component_id, data }.tls_serialized_len()
192                })
193                .sum(),
194        )
195    }
196}
197
198impl tls_codec::Deserialize for ComponentDataMap {
199    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
200    where
201        Self: Sized,
202    {
203        let tlspl_value: Vec<ComponentData> = <_>::tls_deserialize(bytes)?;
204
205        Ok(Self(BTreeMap::from_iter(
206            tlspl_value
207                .into_iter()
208                .map(|cdata| (cdata.component_id, cdata.data)),
209        )))
210    }
211}
212
213impl tls_codec::Serialize for ComponentDataMap {
214    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
215        // TODO: Improve this by not allocating a vec of refs
216        self.iter()
217            .map(|(component_id, data)| ComponentDataRef { component_id, data })
218            .collect::<Vec<ComponentDataRef>>()
219            .tls_serialize(writer)
220    }
221}
222
223impl std::ops::Deref for ComponentDataMap {
224    type Target = BTreeMap<ComponentId, Vec<u8>>;
225    fn deref(&self) -> &Self::Target {
226        &self.0
227    }
228}
229
230impl std::ops::DerefMut for ComponentDataMap {
231    fn deref_mut(&mut self) -> &mut Self::Target {
232        &mut self.0
233    }
234}
235
236impl From<Vec<ComponentData>> for ComponentDataMap {
237    fn from(value: Vec<ComponentData>) -> Self {
238        Self(BTreeMap::from_iter(
239            value
240                .into_iter()
241                .map(|component| (component.component_id, component.data)),
242        ))
243    }
244}
245
246#[allow(clippy::from_over_into)]
247impl Into<Vec<ComponentData>> for ComponentDataMap {
248    fn into(self) -> Vec<ComponentData> {
249        self.0
250            .into_iter()
251            .map(|(component_id, data)| ComponentData { component_id, data })
252            .collect()
253    }
254}
255
/// Dictionary of per-component application data, keyed by [`ComponentId`].
///
/// It is backed by [`ComponentDataMap`], a `BTreeMap`, so entries stay ordered by
/// `component_id` and each ID appears at most once. Conversion from and to the wire-level
/// `Vec<ComponentData>` happens at serialization/deserialization time.
#[derive(
    Debug,
    Default,
    Clone,
    PartialEq,
    Eq,
    Hash,
    tls_codec::TlsSize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSerialize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ApplicationDataDictionary {
    pub component_data: ComponentDataMap,
}
275
276impl ApplicationDataDictionary {
277    pub fn iter_components(&self) -> impl Iterator<Item = ComponentDataRef<'_>> {
278        self.component_data
279            .iter()
280            .map(|(component_id, data)| ComponentDataRef { component_id, data })
281    }
282
283    pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
284        self.component_data.extract_component::<C>()
285    }
286
287    /// Returns `true` if newly inserted
288    pub fn insert_or_update_component<C: Component>(
289        &mut self,
290        component: &C,
291    ) -> crate::MlsSpecResult<bool> {
292        self.component_data.insert_or_update_component(component)
293    }
294
295    /// Applies an ApplicationDataUpdate proposal
296    ///
297    /// Returns `false` in only one case: when an `op` is set to `remove` tries to
298    /// remove a non-existing component, which is a soft-error in itself
299    pub fn apply_update(&mut self, update: AppDataUpdate) -> bool {
300        match update.op {
301            ApplicationDataUpdateOperation::Update { update: data } => {
302                *self.component_data.entry(update.component_id).or_default() = data;
303                true
304            }
305            ApplicationDataUpdateOperation::Remove => {
306                self.component_data.remove(&update.component_id).is_some()
307            }
308        }
309    }
310}
311
312impl From<ApplicationDataDictionary> for crate::group::extensions::Extension {
313    fn from(val: ApplicationDataDictionary) -> Self {
314        crate::group::extensions::Extension::ApplicationData(val)
315    }
316}
317
/// `u8` wire discriminants for [`ApplicationDataUpdateOperation`].
///
/// `Invalid` (`0x00`) has no matching operation variant, so serialization never emits it.
/// NOTE(review): presumably the derived decoder rejects it as an unknown discriminant —
/// confirm against `tls_codec` derive behavior. With `serde`, values are (de)serialized as
/// their integer representation via `serde_repr`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[repr(u8)]
#[cfg_attr(
    feature = "serde",
    derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
)]
pub enum ApplicationDataUpdateOperationType {
    Invalid = 0x00,
    Update = 0x01,
    Remove = 0x02,
}
337
/// The operation carried by an [`AppDataUpdate`].
///
/// The TLS discriminant is taken from [`ApplicationDataUpdateOperationType`]. `Update` carries
/// the new length-prefixed payload; `Remove` carries nothing.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ApplicationDataUpdateOperation {
    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Update")]
    Update {
        #[tls_codec(with = "crate::tlspl::bytes")]
        update: Vec<u8>,
    },
    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Remove")]
    Remove,
}
358
/// A single change targeting one component's entry in an [`ApplicationDataDictionary`].
///
/// It is applied with `ApplicationDataDictionary::apply_update`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct AppDataUpdate {
    pub component_id: ComponentId,
    pub op: ApplicationDataUpdateOperation,
}
373
374impl AppDataUpdate {
375    /// Allows to extract a concrete `Component` from an update operation
376    ///
377    /// Returns Ok(None) if the update is a `Remove` operation
378    /// Otherwise returns Ok(Some(C)) unless an error occurs
379    pub fn extract_component_update<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
380        let type_component_id = C::component_id();
381        if type_component_id != self.component_id {
382            return Err(crate::MlsSpecError::SafeAppComponentIdMismatch {
383                expected: type_component_id,
384                actual: self.component_id,
385            });
386        }
387
388        let ApplicationDataUpdateOperation::Update { update } = &self.op else {
389            return Ok(None);
390        };
391
392        Ok(Some(C::from_tls_bytes(update)?))
393    }
394}
395
/// A standalone `(component_id, data)` pair with a length-prefixed opaque payload.
///
/// It has the same field layout as [`ComponentData`].
/// NOTE(review): presumably used for per-message application data from the MLS extensions
/// draft — confirm against callers.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ApplicationData {
    pub component_id: ComponentId,
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub data: Vec<u8>,
}

/// Alias: `AppEphemeral` shares the [`ApplicationData`] structure.
pub type AppEphemeral = ApplicationData;
413
414#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
415#[cfg_attr(feature = "serde", derive(serde::Serialize))]
416#[cfg_attr(feature = "serde", serde(transparent))]
417pub struct SafeAadItemRef<'a>(ComponentDataRef<'a>);
418
419impl<'a> SafeAadItemRef<'a> {
420    pub fn component_id(&self) -> &ComponentId {
421        self.0.component_id
422    }
423
424    pub fn aad_item_data(&self) -> &[u8] {
425        self.0.data
426    }
427
428    pub fn from_item_data<C: Component>(
429        component_id: &'a ComponentId,
430        aad_item_data: &'a [u8],
431    ) -> Option<Self> {
432        (&C::component_id() == component_id).then_some(SafeAadItemRef(ComponentDataRef {
433            component_id,
434            data: aad_item_data,
435        }))
436    }
437}
438
439#[derive(
440    Debug,
441    Clone,
442    PartialEq,
443    Eq,
444    Hash,
445    tls_codec::TlsSerialize,
446    tls_codec::TlsDeserialize,
447    tls_codec::TlsSize,
448)]
449#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
450#[cfg_attr(feature = "serde", serde(transparent))]
451pub struct SafeAadItem(ComponentData);
452
453impl SafeAadItem {
454    pub fn as_ref(&self) -> SafeAadItemRef<'_> {
455        SafeAadItemRef(self.0.as_ref())
456    }
457}
458
459#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
460#[cfg_attr(feature = "serde", derive(serde::Serialize))]
461pub struct SafeAadRef<'a> {
462    pub aad_items: &'a [&'a SafeAadItemRef<'a>],
463}
464
465impl SafeAadRef<'_> {
466    pub fn is_ordered_and_unique(&self) -> bool {
467        let mut iter = self.aad_items.iter().peekable();
468
469        while let Some(item) = iter.next() {
470            let Some(next) = iter.peek() else {
471                continue;
472            };
473
474            if item.component_id() >= next.component_id() {
475                return false;
476            }
477        }
478
479        true
480    }
481}
482
483impl<'a> From<&'a [&'a SafeAadItemRef<'a>]> for SafeAadRef<'a> {
484    fn from(aad_items: &'a [&'a SafeAadItemRef<'a>]) -> Self {
485        Self { aad_items }
486    }
487}
488
489#[derive(
490    Debug,
491    Default,
492    Clone,
493    PartialEq,
494    Eq,
495    Hash,
496    tls_codec::TlsSerialize,
497    tls_codec::TlsDeserialize,
498    tls_codec::TlsSize,
499)]
500#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
501pub struct SafeAad {
502    aad_items: ComponentDataMap,
503}
504
505impl SafeAad {
506    pub fn iter_components(&self) -> impl Iterator<Item = SafeAadItemRef<'_>> {
507        self.aad_items
508            .iter()
509            .map(|(component_id, data)| SafeAadItemRef(ComponentDataRef { component_id, data }))
510    }
511
512    pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
513        self.aad_items.extract_component::<C>()
514    }
515
516    /// Returns `true` if newly inserted
517    pub fn insert_or_update_component<C: Component>(
518        &mut self,
519        component: &C,
520    ) -> crate::MlsSpecResult<bool> {
521        self.aad_items.insert_or_update_component(component)
522    }
523}
524
/// A list of [`crate::defs::WireFormat`] values, TLS-PL encoded as a vector.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WireFormats {
    pub wire_formats: Vec<crate::defs::WireFormat>,
}
539
/// A list of [`ComponentId`]s, TLS-PL encoded as a vector.
///
/// This is the payload of both [`AppComponents`] and [`SafeAadComponent`].
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComponentsList {
    pub component_ids: Vec<ComponentId>,
}
553
/// Component whose payload is a [`ComponentsList`].
///
/// NOTE(review): presumably advertises which application components are in use — confirm
/// against the MLS extensions draft.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
pub struct AppComponents(pub ComponentsList);

impl Component for AppComponents {
    /// Registered under `super::APP_COMPONENTS_ID`.
    fn component_id() -> ComponentId {
        super::APP_COMPONENTS_ID
    }
}
570
/// Component whose payload is a [`ComponentsList`].
///
/// NOTE(review): presumably lists the component IDs allowed to contribute safe-AAD items —
/// confirm against the MLS extensions draft.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
pub struct SafeAadComponent(pub ComponentsList);

impl Component for SafeAadComponent {
    /// Registered under `super::SAFE_AAD_ID`.
    fn component_id() -> ComponentId {
        super::SAFE_AAD_ID
    }
}
587
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{ApplicationDataDictionary, Component, SafeAad, SafeAadItemRef, SafeAadRef};
    use crate::{
        drafts::mls_extensions::last_resort_keypackage::LastResortKeyPackage,
        generate_roundtrip_test,
    };

    // The maps below are built from `BTreeMap`s, which sort the `(1, 3, 2)` entries into
    // ascending order before anything is encoded. These roundtrips therefore only cover
    // canonical (sorted, duplicate-free) data.
    // NOTE(review): presumably `generate_roundtrip_test!` encodes the value, decodes it back
    // and asserts equality — confirm against the macro definition.
    generate_roundtrip_test!(can_roundtrip_appdatadict, {
        ApplicationDataDictionary {
            component_data: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    generate_roundtrip_test!(can_roundtrip_safeaad, {
        SafeAad {
            aad_items: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    /// Inserts a component into a `SafeAad`, then checks that a single-item `SafeAadRef`
    /// for the same component reports itself as ordered and unique.
    #[test]
    fn can_build_safe_aad() {
        let mut safe_aad = SafeAad::default();
        safe_aad
            .insert_or_update_component(&LastResortKeyPackage)
            .unwrap();

        let cid = LastResortKeyPackage::component_id();
        // `from_item_data` returns `None` on an ID mismatch; here the IDs match by construction.
        let aad_item_ref =
            SafeAadItemRef::from_item_data::<LastResortKeyPackage>(&cid, &[]).unwrap();

        let items = &[&aad_item_ref];
        let safe_ref = SafeAadRef::from(items.as_slice());
        assert!(safe_ref.is_ordered_and_unique());
    }
}