Skip to main content

mls_spec/drafts/mls_extensions/
safe_application.rs

1use std::collections::BTreeMap;
2
3use crate::{SensitiveBytes, key_schedule::PreSharedKeyId};
4
/// Numeric identifier of an application component.
pub type ComponentId = u16;

/// Component identifiers set aside as GREASE values.
///
/// Every entry has the shape `0xNANA`, with `N` going from `0` through `7`.
pub const COMPONENT_ID_GREASE_VALUES: [ComponentId; 8] = [
    0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A, 0x4A4A, 0x5A5A, 0x6A6A, 0x7A7A,
];
10
11pub trait Component: crate::Parsable + crate::Serializable {
12    fn component_id() -> ComponentId;
13
14    fn psk(psk_id: Vec<u8>, psk_nonce: SensitiveBytes) -> PreSharedKeyId {
15        PreSharedKeyId {
16            psktype: crate::key_schedule::PreSharedKeyIdPskType::Application(
17                crate::key_schedule::ApplicationPsk {
18                    component_id: Self::component_id(),
19                    psk_id,
20                },
21            ),
22            psk_nonce,
23        }
24    }
25
26    fn to_component_data(&self) -> crate::MlsSpecResult<ComponentData> {
27        Ok(ComponentData {
28            component_id: Self::component_id(),
29            data: self.to_tls_bytes()?,
30        })
31    }
32}
33
34#[derive(
35    Debug, Clone, Default, PartialEq, Eq, strum::IntoStaticStr, strum::EnumString, strum::Display,
36)]
37#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
38#[repr(u8)]
39pub enum ComponentOperationBaseLabel {
40    #[default]
41    #[strum(serialize = "MLS Component")]
42    MlsComponent,
43    /// Other cases. Unlikely to ever happen but whatever!
44    Custom(String),
45}
46
47impl tls_codec::Size for ComponentOperationBaseLabel {
48    fn tls_serialized_len(&self) -> usize {
49        crate::tlspl::string::tls_serialized_len(self.into())
50    }
51}
52
53impl tls_codec::Serialize for ComponentOperationBaseLabel {
54    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
55        crate::tlspl::string::tls_serialize(self.into(), writer)
56    }
57}
58
59impl tls_codec::Deserialize for ComponentOperationBaseLabel {
60    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
61    where
62        Self: Sized,
63    {
64        let raw_str = crate::tlspl::string::tls_deserialize(bytes)?;
65        Ok(Self::try_from(raw_str.as_str()).unwrap_or(Self::Custom(raw_str)))
66    }
67}
68
69#[derive(
70    Debug,
71    Clone,
72    PartialEq,
73    Eq,
74    tls_codec::TlsSerialize,
75    tls_codec::TlsDeserialize,
76    tls_codec::TlsSize,
77)]
78#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
79pub struct ComponentOperationLabel {
80    pub base_label: ComponentOperationBaseLabel,
81    pub component_id: ComponentId,
82    #[tls_codec(with = "crate::tlspl::bytes")]
83    pub label: Vec<u8>,
84}
85
86#[derive(
87    Debug,
88    Clone,
89    PartialEq,
90    Eq,
91    Hash,
92    tls_codec::TlsSerialize,
93    tls_codec::TlsDeserialize,
94    tls_codec::TlsSize,
95)]
96#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
97pub struct ComponentData {
98    pub component_id: ComponentId,
99    #[tls_codec(with = "crate::tlspl::bytes")]
100    pub data: Vec<u8>,
101}
102
103impl ComponentData {
104    pub fn as_ref(&self) -> ComponentDataRef<'_> {
105        ComponentDataRef {
106            component_id: &self.component_id,
107            data: &self.data,
108        }
109    }
110}
111
112#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
113#[cfg_attr(feature = "serde", derive(serde::Serialize))]
114pub struct ComponentDataRef<'a> {
115    pub component_id: &'a ComponentId,
116    #[tls_codec(with = "crate::tlspl::bytes")]
117    pub data: &'a [u8],
118}
119
120/// Utilitary struct that contains a `BTreeMap` in order to preserve ordering and unicity
121///
122/// Also takes extra care to make sure that the `serde` representation when serialized
123/// is equivalent to the TLS-PL version of it
124#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
125#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
126#[cfg_attr(
127    feature = "serde",
128    serde(from = "Vec<ComponentData>", into = "Vec<ComponentData>")
129)]
130pub struct ComponentDataMap(BTreeMap<ComponentId, Vec<u8>>);
131
132impl ComponentDataMap {
133    fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
134        self.0
135            .get(&C::component_id())
136            .map(|data| C::from_tls_bytes(data))
137            .transpose()
138    }
139
140    fn insert_or_update_component<C: Component>(
141        &mut self,
142        component: &C,
143    ) -> crate::MlsSpecResult<bool> {
144        // This is put before to make sure we don't error out on serialization before modifying the map
145        let component_data = component.to_tls_bytes()?;
146        match self.0.entry(C::component_id()) {
147            std::collections::btree_map::Entry::Vacant(vacant_entry) => {
148                vacant_entry.insert(component_data);
149                Ok(true)
150            }
151            std::collections::btree_map::Entry::Occupied(mut occupied_entry) => {
152                *(occupied_entry.get_mut()) = component_data;
153                Ok(false)
154            }
155        }
156    }
157
158    fn iter(&self) -> impl Iterator<Item = (&ComponentId, &[u8])> {
159        self.0.iter().map(|(cid, data)| (cid, data.as_slice()))
160    }
161}
162
163impl tls_codec::Size for ComponentDataMap {
164    fn tls_serialized_len(&self) -> usize {
165        crate::tlspl::tls_serialized_len_as_vlvec(
166            self.iter()
167                .map(|(component_id, data)| {
168                    ComponentDataRef { component_id, data }.tls_serialized_len()
169                })
170                .sum(),
171        )
172    }
173}
174
175impl tls_codec::Deserialize for ComponentDataMap {
176    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
177    where
178        Self: Sized,
179    {
180        let tlspl_value: Vec<ComponentData> = <_>::tls_deserialize(bytes)?;
181
182        Ok(Self(BTreeMap::from_iter(
183            tlspl_value
184                .into_iter()
185                .map(|cdata| (cdata.component_id, cdata.data)),
186        )))
187    }
188}
189
190impl tls_codec::Serialize for ComponentDataMap {
191    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
192        // TODO: Improve this by not allocating a vec of refs
193        self.iter()
194            .map(|(component_id, data)| ComponentDataRef { component_id, data })
195            .collect::<Vec<ComponentDataRef>>()
196            .tls_serialize(writer)
197    }
198}
199
200impl std::ops::Deref for ComponentDataMap {
201    type Target = BTreeMap<ComponentId, Vec<u8>>;
202    fn deref(&self) -> &Self::Target {
203        &self.0
204    }
205}
206
207impl std::ops::DerefMut for ComponentDataMap {
208    fn deref_mut(&mut self) -> &mut Self::Target {
209        &mut self.0
210    }
211}
212
213impl From<Vec<ComponentData>> for ComponentDataMap {
214    fn from(value: Vec<ComponentData>) -> Self {
215        Self(BTreeMap::from_iter(
216            value
217                .into_iter()
218                .map(|component| (component.component_id, component.data)),
219        ))
220    }
221}
222
223#[allow(clippy::from_over_into)]
224impl Into<Vec<ComponentData>> for ComponentDataMap {
225    fn into(self) -> Vec<ComponentData> {
226        self.0
227            .into_iter()
228            .map(|(component_id, data)| ComponentData { component_id, data })
229            .collect()
230    }
231}
232
233/// Please note that this ApplicationDataDictionary is backed by a `BTreeMap` to
234/// take care of ordering and deduplication automatically.
235///
236/// The conversion from/to a `Vec<ComponentData>` is done at serialization/deserialization time
237#[derive(
238    Debug,
239    Default,
240    Clone,
241    PartialEq,
242    Eq,
243    Hash,
244    tls_codec::TlsSize,
245    tls_codec::TlsDeserialize,
246    tls_codec::TlsSerialize,
247)]
248#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
249pub struct ApplicationDataDictionary {
250    pub component_data: ComponentDataMap,
251}
252
253impl ApplicationDataDictionary {
254    pub fn iter_components(&self) -> impl Iterator<Item = ComponentDataRef<'_>> {
255        self.component_data
256            .iter()
257            .map(|(component_id, data)| ComponentDataRef { component_id, data })
258    }
259
260    pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
261        self.component_data.extract_component::<C>()
262    }
263
264    /// Returns `true` if newly inserted
265    pub fn insert_or_update_component<C: Component>(
266        &mut self,
267        component: &C,
268    ) -> crate::MlsSpecResult<bool> {
269        self.component_data.insert_or_update_component(component)
270    }
271
272    /// Applies an ApplicationDataUpdate proposal
273    ///
274    /// Returns `false` in only one case: when an `op` is set to `remove` tries to
275    /// remove a non-existing component, which is a soft-error in itself
276    pub fn apply_update(&mut self, update: AppDataUpdate) -> bool {
277        match update.op {
278            ApplicationDataUpdateOperation::Update { update: data } => {
279                *self.component_data.entry(update.component_id).or_default() = data;
280                true
281            }
282            ApplicationDataUpdateOperation::Remove => {
283                self.component_data.remove(&update.component_id).is_some()
284            }
285        }
286    }
287}
288
289impl From<ApplicationDataDictionary> for crate::group::extensions::Extension {
290    fn from(val: ApplicationDataDictionary) -> Self {
291        crate::group::extensions::Extension::ApplicationData(val)
292    }
293}
294
295#[derive(
296    Debug,
297    Clone,
298    PartialEq,
299    Eq,
300    tls_codec::TlsSerialize,
301    tls_codec::TlsDeserialize,
302    tls_codec::TlsSize,
303)]
304#[repr(u8)]
305#[cfg_attr(
306    feature = "serde",
307    derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
308)]
309pub enum ApplicationDataUpdateOperationType {
310    Invalid = 0x00,
311    Update = 0x01,
312    Remove = 0x02,
313}
314
315#[derive(
316    Debug,
317    Clone,
318    PartialEq,
319    Eq,
320    tls_codec::TlsSerialize,
321    tls_codec::TlsDeserialize,
322    tls_codec::TlsSize,
323)]
324#[repr(u8)]
325#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
326pub enum ApplicationDataUpdateOperation {
327    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Update")]
328    Update {
329        #[tls_codec(with = "crate::tlspl::bytes")]
330        update: Vec<u8>,
331    },
332    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Remove")]
333    Remove,
334}
335
336#[derive(
337    Debug,
338    Clone,
339    PartialEq,
340    Eq,
341    tls_codec::TlsSerialize,
342    tls_codec::TlsDeserialize,
343    tls_codec::TlsSize,
344)]
345#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
346pub struct AppDataUpdate {
347    pub component_id: ComponentId,
348    pub op: ApplicationDataUpdateOperation,
349}
350
351impl AppDataUpdate {
352    /// Allows to extract a concrete `Component` from an update operation
353    ///
354    /// Returns Ok(None) if the update is a `Remove` operation
355    /// Otherwise returns Ok(Some(C)) unless an error occurs
356    pub fn extract_component_update<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
357        let type_component_id = C::component_id();
358        if type_component_id != self.component_id {
359            return Err(crate::MlsSpecError::SafeAppComponentIdMismatch {
360                expected: type_component_id,
361                actual: self.component_id,
362            });
363        }
364
365        let ApplicationDataUpdateOperation::Update { update } = &self.op else {
366            return Ok(None);
367        };
368
369        Ok(Some(C::from_tls_bytes(update)?))
370    }
371}
372
373#[derive(
374    Debug,
375    Clone,
376    PartialEq,
377    Eq,
378    tls_codec::TlsSerialize,
379    tls_codec::TlsDeserialize,
380    tls_codec::TlsSize,
381)]
382#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
383pub struct ApplicationData {
384    pub component_id: ComponentId,
385    #[tls_codec(with = "crate::tlspl::bytes")]
386    pub data: Vec<u8>,
387}
388
389pub type AppEphemeral = ApplicationData;
390
391#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
392#[cfg_attr(feature = "serde", derive(serde::Serialize))]
393#[cfg_attr(feature = "serde", serde(transparent))]
394pub struct SafeAadItemRef<'a>(ComponentDataRef<'a>);
395
396impl<'a> SafeAadItemRef<'a> {
397    pub fn component_id(&self) -> &ComponentId {
398        self.0.component_id
399    }
400
401    pub fn aad_item_data(&self) -> &[u8] {
402        self.0.data
403    }
404
405    pub fn from_item_data<C: Component>(
406        component_id: &'a ComponentId,
407        aad_item_data: &'a [u8],
408    ) -> Option<Self> {
409        (&C::component_id() == component_id).then_some(SafeAadItemRef(ComponentDataRef {
410            component_id,
411            data: aad_item_data,
412        }))
413    }
414}
415
416#[derive(
417    Debug,
418    Clone,
419    PartialEq,
420    Eq,
421    Hash,
422    tls_codec::TlsSerialize,
423    tls_codec::TlsDeserialize,
424    tls_codec::TlsSize,
425)]
426#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
427#[cfg_attr(feature = "serde", serde(transparent))]
428pub struct SafeAadItem(ComponentData);
429
430impl SafeAadItem {
431    pub fn as_ref(&self) -> SafeAadItemRef<'_> {
432        SafeAadItemRef(self.0.as_ref())
433    }
434}
435
436#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
437#[cfg_attr(feature = "serde", derive(serde::Serialize))]
438pub struct SafeAadRef<'a> {
439    pub aad_items: &'a [&'a SafeAadItemRef<'a>],
440}
441
442impl SafeAadRef<'_> {
443    pub fn is_ordered_and_unique(&self) -> bool {
444        let mut iter = self.aad_items.iter().peekable();
445
446        while let Some(item) = iter.next() {
447            let Some(next) = iter.peek() else {
448                continue;
449            };
450
451            if item.component_id() >= next.component_id() {
452                return false;
453            }
454        }
455
456        true
457    }
458}
459
460impl<'a> From<&'a [&'a SafeAadItemRef<'a>]> for SafeAadRef<'a> {
461    fn from(aad_items: &'a [&'a SafeAadItemRef<'a>]) -> Self {
462        Self { aad_items }
463    }
464}
465
466#[derive(
467    Debug,
468    Default,
469    Clone,
470    PartialEq,
471    Eq,
472    Hash,
473    tls_codec::TlsSerialize,
474    tls_codec::TlsDeserialize,
475    tls_codec::TlsSize,
476)]
477#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
478pub struct SafeAad {
479    aad_items: ComponentDataMap,
480}
481
482impl SafeAad {
483    pub fn iter_components(&self) -> impl Iterator<Item = SafeAadItemRef<'_>> {
484        self.aad_items
485            .iter()
486            .map(|(component_id, data)| SafeAadItemRef(ComponentDataRef { component_id, data }))
487    }
488
489    pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
490        self.aad_items.extract_component::<C>()
491    }
492
493    /// Returns `true` if newly inserted
494    pub fn insert_or_update_component<C: Component>(
495        &mut self,
496        component: &C,
497    ) -> crate::MlsSpecResult<bool> {
498        self.aad_items.insert_or_update_component(component)
499    }
500}
501
502#[derive(
503    Debug,
504    Clone,
505    PartialEq,
506    Eq,
507    Hash,
508    tls_codec::TlsSerialize,
509    tls_codec::TlsDeserialize,
510    tls_codec::TlsSize,
511)]
512#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
513pub struct WireFormats {
514    pub wire_formats: Vec<crate::defs::WireFormat>,
515}
516
517#[derive(
518    Debug,
519    Clone,
520    PartialEq,
521    Eq,
522    tls_codec::TlsSerialize,
523    tls_codec::TlsDeserialize,
524    tls_codec::TlsSize,
525)]
526#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
527pub struct ComponentsList {
528    pub component_ids: Vec<ComponentId>,
529}
530
531#[derive(
532    Debug,
533    Clone,
534    PartialEq,
535    Eq,
536    tls_codec::TlsSerialize,
537    tls_codec::TlsDeserialize,
538    tls_codec::TlsSize,
539)]
540pub struct AppComponents(pub ComponentsList);
541
542impl Component for AppComponents {
543    fn component_id() -> ComponentId {
544        super::APP_COMPONENTS_ID
545    }
546}
547
548#[derive(
549    Debug,
550    Clone,
551    PartialEq,
552    Eq,
553    tls_codec::TlsSerialize,
554    tls_codec::TlsDeserialize,
555    tls_codec::TlsSize,
556)]
557pub struct SafeAadComponent(pub ComponentsList);
558
559impl Component for SafeAadComponent {
560    fn component_id() -> ComponentId {
561        super::SAFE_AAD_ID
562    }
563}
564
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{ApplicationDataDictionary, Component, SafeAad, SafeAadItemRef, SafeAadRef};
    use crate::{
        drafts::mls_extensions::last_resort_keypackage::LastResortKeyPackage,
        generate_roundtrip_test,
    };

    // The entries are deliberately listed out of order: the backing `BTreeMap` sorts them.
    generate_roundtrip_test!(can_roundtrip_appdatadict, {
        ApplicationDataDictionary {
            component_data: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    generate_roundtrip_test!(can_roundtrip_safeaad, {
        SafeAad {
            aad_items: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    #[test]
    fn can_build_safe_aad() {
        let mut safe_aad = SafeAad::default();

        // The first insertion of a component is reported as new, a second one as a replacement.
        assert!(
            safe_aad
                .insert_or_update_component(&LastResortKeyPackage)
                .unwrap()
        );
        assert!(
            !safe_aad
                .insert_or_update_component(&LastResortKeyPackage)
                .unwrap()
        );

        // Exactly one item is stored, under the component's own identifier.
        let cid = LastResortKeyPackage::component_id();
        let stored_ids: Vec<_> = safe_aad
            .iter_components()
            .map(|item| *item.component_id())
            .collect();
        assert_eq!(stored_ids, vec![cid]);

        let aad_item_ref =
            SafeAadItemRef::from_item_data::<LastResortKeyPackage>(&cid, &[]).unwrap();

        let items = &[&aad_item_ref];
        let safe_ref = SafeAadRef::from(items.as_slice());
        assert!(safe_ref.is_ordered_and_unique());
    }

    #[test]
    fn safe_aad_ref_rejects_duplicate_components() {
        let cid = LastResortKeyPackage::component_id();
        let aad_item_ref =
            SafeAadItemRef::from_item_data::<LastResortKeyPackage>(&cid, &[]).unwrap();

        // The same identifier twice is not strictly increasing.
        let items = &[&aad_item_ref, &aad_item_ref];
        let safe_ref = SafeAadRef::from(items.as_slice());
        assert!(!safe_ref.is_ordered_and_unique());
    }

    #[test]
    fn safe_aad_item_ref_rejects_foreign_component_id() {
        // Any identifier other than the component's own must be refused.
        let foreign_cid = LastResortKeyPackage::component_id().wrapping_add(1);
        assert!(
            SafeAadItemRef::from_item_data::<LastResortKeyPackage>(&foreign_cid, &[]).is_none()
        );
    }
}