mls_spec/drafts/mls_extensions/
safe_application.rs

1use std::collections::BTreeMap;
2
3use crate::{
4    SensitiveBytes, ToPrefixedLabel,
5    defs::{
6        ProtocolVersion,
7        labels::{PublicKeyEncryptionLabel, SignatureLabel},
8    },
9    key_schedule::PreSharedKeyId,
10};
11
/// Numeric identifier of a Safe Application component (an alias of `u32`).
pub type ComponentId = u32;
13
14pub trait Component: crate::Parsable + crate::Serializable {
15    fn component_id() -> ComponentId;
16
17    fn psk(psk_id: Vec<u8>, psk_nonce: SensitiveBytes) -> PreSharedKeyId {
18        PreSharedKeyId {
19            psktype: crate::key_schedule::PreSharedKeyIdPskType::Application(
20                crate::key_schedule::ApplicationPsk {
21                    component_id: Self::component_id(),
22                    psk_id,
23                },
24            ),
25            psk_nonce,
26        }
27    }
28
29    fn to_component_data(&self) -> crate::MlsSpecResult<ComponentData> {
30        Ok(ComponentData {
31            component_id: Self::component_id(),
32            data: self.to_tls_bytes()?,
33        })
34    }
35}
36
37#[derive(
38    Debug,
39    Clone,
40    PartialEq,
41    Eq,
42    tls_codec::TlsSerialize,
43    tls_codec::TlsDeserialize,
44    tls_codec::TlsSize,
45)]
46#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
47pub struct ComponentOperationLabel {
48    #[tls_codec(with = "crate::tlspl::bytes")]
49    label: Vec<u8>,
50    pub component_id: ComponentId,
51    #[tls_codec(with = "crate::tlspl::bytes")]
52    pub context: Vec<u8>,
53}
54
55impl ComponentOperationLabel {
56    /// This follows the spec for the following construct: <https://www.ietf.org/archive/id/draft-ietf-mls-extensions-06.html#name-hybrid-public-key-encryptio>
57    pub fn new_with_context_for_hpke(component_id: ComponentId, context: Vec<u8>) -> Self {
58        Self {
59            label: PublicKeyEncryptionLabel::SafeApp
60                .to_prefixed_string(ProtocolVersion::default())
61                .into_bytes(),
62            component_id,
63            context,
64        }
65    }
66
67    /// This follows the spec for the following construct: <https://www.ietf.org/archive/id/draft-ietf-mls-extensions-06.html#name-hybrid-public-key-encryptio>
68    pub fn new_for_signature(
69        label: SignatureLabel,
70        component_id: ComponentId,
71        context: Vec<u8>,
72    ) -> (Self, SignatureLabel) {
73        let col = Self {
74            label: label
75                .to_prefixed_string(ProtocolVersion::default())
76                .into_bytes(),
77            component_id,
78            context,
79        };
80
81        (col, SignatureLabel::ComponentOperationLabel)
82    }
83}
84
85#[derive(
86    Debug,
87    Clone,
88    PartialEq,
89    Eq,
90    tls_codec::TlsSerialize,
91    tls_codec::TlsDeserialize,
92    tls_codec::TlsSize,
93)]
94#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
95pub struct ComponentData {
96    pub component_id: ComponentId,
97    #[tls_codec(with = "crate::tlspl::bytes")]
98    pub data: Vec<u8>,
99}
100
101impl ComponentData {
102    pub fn as_ref(&self) -> ComponentDataRef {
103        ComponentDataRef {
104            component_id: &self.component_id,
105            data: &self.data,
106        }
107    }
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
111#[cfg_attr(feature = "serde", derive(serde::Serialize))]
112pub struct ComponentDataRef<'a> {
113    pub component_id: &'a ComponentId,
114    #[tls_codec(with = "crate::tlspl::bytes")]
115    pub data: &'a [u8],
116}
117
118/// Utilitary struct that contains a `BTreeMap` in order to preserve ordering and unicity
119///
120/// Also takes extra care to make sure that the `serde` representation when serialized
121/// is equivalent to the TLS-PL version of it
122#[derive(Debug, Default, Clone, PartialEq, Eq)]
123#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
124#[cfg_attr(
125    feature = "serde",
126    serde(from = "Vec<ComponentData>", into = "Vec<ComponentData>")
127)]
128pub struct ComponentDataMap(BTreeMap<ComponentId, Vec<u8>>);
129
130impl ComponentDataMap {
131    fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
132        self.0
133            .get(&C::component_id())
134            .map(|data| C::from_tls_bytes(data))
135            .transpose()
136    }
137
138    fn insert_or_update_component<C: Component>(
139        &mut self,
140        component: &C,
141    ) -> crate::MlsSpecResult<bool> {
142        // This is put before to make sure we don't error out on serialization before modifying the map
143        let component_data = component.to_tls_bytes()?;
144        match self.0.entry(C::component_id()) {
145            std::collections::btree_map::Entry::Vacant(vacant_entry) => {
146                vacant_entry.insert(component_data);
147                Ok(true)
148            }
149            std::collections::btree_map::Entry::Occupied(mut occupied_entry) => {
150                *(occupied_entry.get_mut()) = component_data;
151                Ok(false)
152            }
153        }
154    }
155
156    fn iter(&self) -> impl Iterator<Item = (&ComponentId, &[u8])> {
157        self.0.iter().map(|(cid, data)| (cid, data.as_slice()))
158    }
159}
160
161impl tls_codec::Size for ComponentDataMap {
162    fn tls_serialized_len(&self) -> usize {
163        crate::tlspl::tls_serialized_len_as_vlvec(
164            self.iter()
165                .map(|(component_id, data)| {
166                    ComponentDataRef { component_id, data }.tls_serialized_len()
167                })
168                .sum(),
169        )
170    }
171}
172
173impl tls_codec::Deserialize for ComponentDataMap {
174    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
175    where
176        Self: Sized,
177    {
178        let tlspl_value: Vec<ComponentData> = <_>::tls_deserialize(bytes)?;
179
180        Ok(Self(BTreeMap::from_iter(
181            tlspl_value
182                .into_iter()
183                .map(|cdata| (cdata.component_id, cdata.data)),
184        )))
185    }
186}
187
188impl tls_codec::Serialize for ComponentDataMap {
189    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
190        // TODO: Improve this by not allocating a vec of refs
191        self.iter()
192            .map(|(component_id, data)| ComponentDataRef { component_id, data })
193            .collect::<Vec<ComponentDataRef>>()
194            .tls_serialize(writer)
195    }
196}
197
198impl std::ops::Deref for ComponentDataMap {
199    type Target = BTreeMap<ComponentId, Vec<u8>>;
200    fn deref(&self) -> &Self::Target {
201        &self.0
202    }
203}
204
205impl std::ops::DerefMut for ComponentDataMap {
206    fn deref_mut(&mut self) -> &mut Self::Target {
207        &mut self.0
208    }
209}
210
211impl From<Vec<ComponentData>> for ComponentDataMap {
212    fn from(value: Vec<ComponentData>) -> Self {
213        Self(BTreeMap::from_iter(
214            value
215                .into_iter()
216                .map(|component| (component.component_id, component.data)),
217        ))
218    }
219}
220
221#[allow(clippy::from_over_into)]
222impl Into<Vec<ComponentData>> for ComponentDataMap {
223    fn into(self) -> Vec<ComponentData> {
224        self.0
225            .into_iter()
226            .map(|(component_id, data)| ComponentData { component_id, data })
227            .collect()
228    }
229}
230
231/// Please note that this ApplicationDataDictionary is backed by a `BTreeMap` to
232/// take care of ordering and deduplication automatically.
233///
234/// The conversion from/to a `Vec<ComponentData>` is done at serialization/deserialization time
235#[derive(
236    Debug,
237    Default,
238    Clone,
239    PartialEq,
240    Eq,
241    tls_codec::TlsSize,
242    tls_codec::TlsDeserialize,
243    tls_codec::TlsSerialize,
244)]
245#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
246pub struct ApplicationDataDictionary {
247    pub component_data: ComponentDataMap,
248}
249
250impl ApplicationDataDictionary {
251    pub fn iter_components(&self) -> impl Iterator<Item = ComponentDataRef> {
252        self.component_data
253            .iter()
254            .map(|(component_id, data)| ComponentDataRef { component_id, data })
255    }
256
257    pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
258        self.component_data.extract_component::<C>()
259    }
260
261    /// Returns `true` if newly inserted
262    pub fn insert_or_update_component<C: Component>(
263        &mut self,
264        component: &C,
265    ) -> crate::MlsSpecResult<bool> {
266        self.component_data.insert_or_update_component(component)
267    }
268
269    /// Applies an ApplicationDataUpdate proposal
270    ///
271    /// Returns `false` in only one case: when an `op` is set to `remove` tries to
272    /// remove a non-existing component, which is a soft-error in itself
273    pub fn apply_update(&mut self, update: AppDataUpdate) -> bool {
274        match update.op {
275            ApplicationDataUpdateOperation::Update { update: data } => {
276                *self.component_data.entry(update.component_id).or_default() = data;
277                true
278            }
279            ApplicationDataUpdateOperation::Remove => {
280                self.component_data.remove(&update.component_id).is_some()
281            }
282        }
283    }
284}
285
286impl From<ApplicationDataDictionary> for crate::group::extensions::Extension {
287    fn from(val: ApplicationDataDictionary) -> Self {
288        crate::group::extensions::Extension::ApplicationData(val)
289    }
290}
291
292#[derive(
293    Debug,
294    Clone,
295    PartialEq,
296    Eq,
297    tls_codec::TlsSerialize,
298    tls_codec::TlsDeserialize,
299    tls_codec::TlsSize,
300)]
301#[repr(u8)]
302#[cfg_attr(
303    feature = "serde",
304    derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
305)]
306pub enum ApplicationDataUpdateOperationType {
307    Invalid = 0x00,
308    Update = 0x01,
309    Remove = 0x02,
310}
311
312#[derive(
313    Debug,
314    Clone,
315    PartialEq,
316    Eq,
317    tls_codec::TlsSerialize,
318    tls_codec::TlsDeserialize,
319    tls_codec::TlsSize,
320)]
321#[repr(u8)]
322#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
323pub enum ApplicationDataUpdateOperation {
324    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Update")]
325    Update {
326        #[tls_codec(with = "crate::tlspl::bytes")]
327        update: Vec<u8>,
328    },
329    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Remove")]
330    Remove,
331}
332
333#[derive(
334    Debug,
335    Clone,
336    PartialEq,
337    Eq,
338    tls_codec::TlsSerialize,
339    tls_codec::TlsDeserialize,
340    tls_codec::TlsSize,
341)]
342#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
343pub struct AppDataUpdate {
344    pub component_id: ComponentId,
345    pub op: ApplicationDataUpdateOperation,
346}
347
348impl AppDataUpdate {
349    /// Allows to extract a concrete `Component` from an update operation
350    ///
351    /// Returns Ok(None) if the update is a `Remove` operation
352    /// Otherwise returns Ok(Some(C)) unless an error occurs
353    pub fn extract_component_update<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
354        let type_component_id = C::component_id();
355        if type_component_id != self.component_id {
356            return Err(crate::MlsSpecError::SafeAppComponentIdMismatch {
357                expected: type_component_id,
358                actual: self.component_id,
359            });
360        }
361
362        let ApplicationDataUpdateOperation::Update { update } = &self.op else {
363            return Ok(None);
364        };
365
366        Ok(Some(C::from_tls_bytes(update)?))
367    }
368}
369
370#[derive(
371    Debug,
372    Clone,
373    PartialEq,
374    Eq,
375    tls_codec::TlsSerialize,
376    tls_codec::TlsDeserialize,
377    tls_codec::TlsSize,
378)]
379#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
380pub struct ApplicationData {
381    pub component_id: ComponentId,
382    #[tls_codec(with = "crate::tlspl::bytes")]
383    pub data: Vec<u8>,
384}
385
386pub type AppEphemeral = ApplicationData;
387
388#[derive(Debug, Clone, PartialEq, Eq, Hash, tls_codec::TlsSerialize, tls_codec::TlsSize)]
389#[cfg_attr(feature = "serde", derive(serde::Serialize))]
390pub struct SafeAadItemRefOld<'a> {
391    pub component_id: &'a ComponentId,
392    #[tls_codec(with = "crate::tlspl::bytes")]
393    pub aad_item_data: &'a [u8],
394}
395
396#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
397#[cfg_attr(feature = "serde", derive(serde::Serialize))]
398#[cfg_attr(feature = "serde", serde(transparent))]
399pub struct SafeAadItemRef<'a>(ComponentDataRef<'a>);
400
401impl SafeAadItemRef<'_> {
402    pub fn component_id(&self) -> &ComponentId {
403        self.0.component_id
404    }
405
406    pub fn aad_item_data(&self) -> &[u8] {
407        self.0.data
408    }
409}
410
411#[derive(
412    Debug,
413    Clone,
414    PartialEq,
415    Eq,
416    tls_codec::TlsSerialize,
417    tls_codec::TlsDeserialize,
418    tls_codec::TlsSize,
419)]
420#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
421#[cfg_attr(feature = "serde", serde(transparent))]
422pub struct SafeAadItem(ComponentData);
423
424impl SafeAadItem {
425    pub fn as_ref(&self) -> SafeAadItemRef {
426        SafeAadItemRef(self.0.as_ref())
427    }
428}
429
430#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
431#[cfg_attr(feature = "serde", derive(serde::Serialize))]
432pub struct SafeAadRef<'a> {
433    pub aad_items: &'a [&'a SafeAadItemRef<'a>],
434}
435
436#[derive(
437    Debug,
438    Clone,
439    PartialEq,
440    Eq,
441    tls_codec::TlsSerialize,
442    tls_codec::TlsDeserialize,
443    tls_codec::TlsSize,
444)]
445#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
446pub struct SafeAad {
447    aad_items: ComponentDataMap,
448}
449
450impl SafeAad {
451    pub fn iter_components(&self) -> impl Iterator<Item = SafeAadItemRef> {
452        self.aad_items
453            .iter()
454            .map(|(component_id, data)| SafeAadItemRef(ComponentDataRef { component_id, data }))
455    }
456
457    pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
458        self.aad_items.extract_component::<C>()
459    }
460
461    /// Returns `true` if newly inserted
462    pub fn insert_or_update_component<C: Component>(
463        &mut self,
464        component: &C,
465    ) -> crate::MlsSpecResult<bool> {
466        self.aad_items.insert_or_update_component(component)
467    }
468}
469
470#[derive(
471    Debug,
472    Clone,
473    PartialEq,
474    Eq,
475    tls_codec::TlsSerialize,
476    tls_codec::TlsDeserialize,
477    tls_codec::TlsSize,
478)]
479#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
480pub struct WireFormats {
481    pub wire_formats: Vec<crate::defs::WireFormat>,
482}
483
484#[derive(
485    Debug,
486    Clone,
487    PartialEq,
488    Eq,
489    tls_codec::TlsSerialize,
490    tls_codec::TlsDeserialize,
491    tls_codec::TlsSize,
492)]
493#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
494pub struct ComponentsList {
495    pub component_ids: Vec<ComponentId>,
496}
497
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{ApplicationDataDictionary, ComponentDataMap, SafeAad};
    use crate::generate_roundtrip_test;

    /// Shared fixture: three entries inserted out of order (1, 3, 2). The
    /// `BTreeMap` sorts them regardless of insertion order.
    fn unordered_component_map() -> ComponentDataMap {
        let mut entries = BTreeMap::new();
        entries.insert(1, vec![1]);
        entries.insert(3, vec![3]);
        entries.insert(2, vec![2]);
        ComponentDataMap(entries)
    }

    generate_roundtrip_test!(can_roundtrip_appdatadict, {
        ApplicationDataDictionary {
            component_data: unordered_component_map(),
        }
    });

    generate_roundtrip_test!(can_roundtrip_safeaad, {
        SafeAad {
            aad_items: unordered_component_map(),
        }
    });
}