1use std::collections::BTreeMap;
2
3use crate::{
4 SensitiveBytes, ToPrefixedLabel,
5 defs::{
6 ProtocolVersion,
7 labels::{PublicKeyEncryptionLabel, SignatureLabel},
8 },
9 key_schedule::PreSharedKeyId,
10};
11
/// Numeric identifier of a component; an alias for `u32`.
pub type ComponentId = u32;
13
14pub trait Component: crate::Parsable + crate::Serializable {
15 fn component_id() -> ComponentId;
16
17 fn psk(psk_id: Vec<u8>, psk_nonce: SensitiveBytes) -> PreSharedKeyId {
18 PreSharedKeyId {
19 psktype: crate::key_schedule::PreSharedKeyIdPskType::Application(
20 crate::key_schedule::ApplicationPsk {
21 component_id: Self::component_id(),
22 psk_id,
23 },
24 ),
25 psk_nonce,
26 }
27 }
28
29 fn to_component_data(&self) -> crate::MlsSpecResult<ComponentData> {
30 Ok(ComponentData {
31 component_id: Self::component_id(),
32 data: self.to_tls_bytes()?,
33 })
34 }
35}
36
/// Label structure made of a byte label, a component id and a context byte string.
///
/// The TLS encoding is derived; `label` and `context` are encoded through
/// `crate::tlspl::bytes`. Fields are serialized in declaration order.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComponentOperationLabel {
    // Private: only set through the constructors of this type.
    #[tls_codec(with = "crate::tlspl::bytes")]
    label: Vec<u8>,
    /// Identifier of the component the operation belongs to.
    pub component_id: ComponentId,
    /// Caller-provided context bytes.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub context: Vec<u8>,
}
54
55impl ComponentOperationLabel {
56 pub fn new_with_context_for_hpke(component_id: ComponentId, context: Vec<u8>) -> Self {
58 Self {
59 label: PublicKeyEncryptionLabel::SafeApp
60 .to_prefixed_string(ProtocolVersion::default())
61 .into_bytes(),
62 component_id,
63 context,
64 }
65 }
66
67 pub fn new_for_signature(
69 label: SignatureLabel,
70 component_id: ComponentId,
71 context: Vec<u8>,
72 ) -> (Self, SignatureLabel) {
73 let col = Self {
74 label: label
75 .to_prefixed_string(ProtocolVersion::default())
76 .into_bytes(),
77 component_id,
78 context,
79 };
80
81 (col, SignatureLabel::ComponentOperationLabel)
82 }
83}
84
/// Owned pair of a component id and that component's raw bytes.
///
/// The TLS encoding is derived; `data` is encoded through `crate::tlspl::bytes`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComponentData {
    /// Identifier of the component the bytes belong to.
    pub component_id: ComponentId,
    /// Raw bytes stored for the component.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub data: Vec<u8>,
}
100
101impl ComponentData {
102 pub fn as_ref(&self) -> ComponentDataRef {
103 ComponentDataRef {
104 component_id: &self.component_id,
105 data: &self.data,
106 }
107 }
108}
109
/// Borrowed counterpart of [`ComponentData`]: a reference to a component id plus a byte slice.
///
/// Only the serializing TLS traits are derived (no `TlsDeserialize`), and with the `serde`
/// feature only `Serialize`.
#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ComponentDataRef<'a> {
    /// Borrowed component identifier.
    pub component_id: &'a ComponentId,
    /// Borrowed component bytes, encoded through `crate::tlspl::bytes`.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub data: &'a [u8],
}
117
/// Map from component id to that component's raw bytes, backed by a `BTreeMap`.
///
/// With the `serde` feature enabled the map is (de)serialized by converting from/into a
/// `Vec<ComponentData>`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(from = "Vec<ComponentData>", into = "Vec<ComponentData>")
)]
pub struct ComponentDataMap(BTreeMap<ComponentId, Vec<u8>>);
129
130impl ComponentDataMap {
131 fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
132 self.0
133 .get(&C::component_id())
134 .map(|data| C::from_tls_bytes(data))
135 .transpose()
136 }
137
138 fn insert_or_update_component<C: Component>(
139 &mut self,
140 component: &C,
141 ) -> crate::MlsSpecResult<bool> {
142 let component_data = component.to_tls_bytes()?;
144 match self.0.entry(C::component_id()) {
145 std::collections::btree_map::Entry::Vacant(vacant_entry) => {
146 vacant_entry.insert(component_data);
147 Ok(true)
148 }
149 std::collections::btree_map::Entry::Occupied(mut occupied_entry) => {
150 *(occupied_entry.get_mut()) = component_data;
151 Ok(false)
152 }
153 }
154 }
155
156 fn iter(&self) -> impl Iterator<Item = (&ComponentId, &[u8])> {
157 self.0.iter().map(|(cid, data)| (cid, data.as_slice()))
158 }
159}
160
161impl tls_codec::Size for ComponentDataMap {
162 fn tls_serialized_len(&self) -> usize {
163 crate::tlspl::tls_serialized_len_as_vlvec(
164 self.iter()
165 .map(|(component_id, data)| {
166 ComponentDataRef { component_id, data }.tls_serialized_len()
167 })
168 .sum(),
169 )
170 }
171}
172
173impl tls_codec::Deserialize for ComponentDataMap {
174 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
175 where
176 Self: Sized,
177 {
178 let tlspl_value: Vec<ComponentData> = <_>::tls_deserialize(bytes)?;
179
180 Ok(Self(BTreeMap::from_iter(
181 tlspl_value
182 .into_iter()
183 .map(|cdata| (cdata.component_id, cdata.data)),
184 )))
185 }
186}
187
188impl tls_codec::Serialize for ComponentDataMap {
189 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
190 self.iter()
192 .map(|(component_id, data)| ComponentDataRef { component_id, data })
193 .collect::<Vec<ComponentDataRef>>()
194 .tls_serialize(writer)
195 }
196}
197
/// Gives read access to the wrapped `BTreeMap`, so its `&self` methods can be called
/// directly on a [`ComponentDataMap`].
impl std::ops::Deref for ComponentDataMap {
    type Target = BTreeMap<ComponentId, Vec<u8>>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Gives mutable access to the wrapped `BTreeMap` (e.g. `entry`, `insert`, `remove`).
impl std::ops::DerefMut for ComponentDataMap {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
210
211impl From<Vec<ComponentData>> for ComponentDataMap {
212 fn from(value: Vec<ComponentData>) -> Self {
213 Self(BTreeMap::from_iter(
214 value
215 .into_iter()
216 .map(|component| (component.component_id, component.data)),
217 ))
218 }
219}
220
221#[allow(clippy::from_over_into)]
222impl Into<Vec<ComponentData>> for ComponentDataMap {
223 fn into(self) -> Vec<ComponentData> {
224 self.0
225 .into_iter()
226 .map(|(component_id, data)| ComponentData { component_id, data })
227 .collect()
228 }
229}
230
/// Dictionary of per-component data, wrapping a single [`ComponentDataMap`].
#[derive(
    Debug,
    Default,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSerialize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ApplicationDataDictionary {
    /// Stored bytes, keyed by component id.
    pub component_data: ComponentDataMap,
}
249
250impl ApplicationDataDictionary {
251 pub fn iter_components(&self) -> impl Iterator<Item = ComponentDataRef> {
252 self.component_data
253 .iter()
254 .map(|(component_id, data)| ComponentDataRef { component_id, data })
255 }
256
257 pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
258 self.component_data.extract_component::<C>()
259 }
260
261 pub fn insert_or_update_component<C: Component>(
263 &mut self,
264 component: &C,
265 ) -> crate::MlsSpecResult<bool> {
266 self.component_data.insert_or_update_component(component)
267 }
268
269 pub fn apply_update(&mut self, update: AppDataUpdate) -> bool {
274 match update.op {
275 ApplicationDataUpdateOperation::Update { update: data } => {
276 *self.component_data.entry(update.component_id).or_default() = data;
277 true
278 }
279 ApplicationDataUpdateOperation::Remove => {
280 self.component_data.remove(&update.component_id).is_some()
281 }
282 }
283 }
284}
285
286impl From<ApplicationDataDictionary> for crate::group::extensions::Extension {
287 fn from(val: ApplicationDataDictionary) -> Self {
288 crate::group::extensions::Extension::ApplicationData(val)
289 }
290}
291
/// One-byte discriminant identifying the kind of an application data update.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[repr(u8)]
#[cfg_attr(
    feature = "serde",
    derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
)]
pub enum ApplicationDataUpdateOperationType {
    /// Value `0x00`.
    Invalid = 0x00,
    /// Value `0x01`.
    Update = 0x01,
    /// Value `0x02`.
    Remove = 0x02,
}
311
/// Operation carried by an [`AppDataUpdate`].
///
/// The TLS discriminants are taken from [`ApplicationDataUpdateOperationType`].
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ApplicationDataUpdateOperation {
    /// Carries the new bytes for the component.
    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Update")]
    Update {
        /// Update payload, encoded through `crate::tlspl::bytes`.
        #[tls_codec(with = "crate::tlspl::bytes")]
        update: Vec<u8>,
    },
    /// Carries no payload.
    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Remove")]
    Remove,
}
332
/// An update targeting one component: the component id plus the operation to perform.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct AppDataUpdate {
    /// Identifier of the targeted component.
    pub component_id: ComponentId,
    /// Operation to apply to that component.
    pub op: ApplicationDataUpdateOperation,
}
347
348impl AppDataUpdate {
349 pub fn extract_component_update<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
354 let type_component_id = C::component_id();
355 if type_component_id != self.component_id {
356 return Err(crate::MlsSpecError::SafeAppComponentIdMismatch {
357 expected: type_component_id,
358 actual: self.component_id,
359 });
360 }
361
362 let ApplicationDataUpdateOperation::Update { update } = &self.op else {
363 return Ok(None);
364 };
365
366 Ok(Some(C::from_tls_bytes(update)?))
367 }
368}
369
/// A component id together with raw application bytes.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ApplicationData {
    /// Identifier of the component the bytes belong to.
    pub component_id: ComponentId,
    /// Raw bytes, encoded through `crate::tlspl::bytes`.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub data: Vec<u8>,
}

/// Alias of [`ApplicationData`]; both names refer to the same type.
pub type AppEphemeral = ApplicationData;
387
/// Borrowed AAD item: a newtype around [`ComponentDataRef`].
///
/// With the `serde` feature it serializes transparently as the wrapped value.
#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct SafeAadItemRef<'a>(ComponentDataRef<'a>);
392
393impl<'a> SafeAadItemRef<'a> {
394 pub fn component_id(&self) -> &ComponentId {
395 self.0.component_id
396 }
397
398 pub fn aad_item_data(&self) -> &[u8] {
399 self.0.data
400 }
401
402 pub fn from_item_data<C: Component>(
403 component_id: &'a ComponentId,
404 aad_item_data: &'a [u8],
405 ) -> Option<Self> {
406 (&C::component_id() == component_id).then(|| {
407 SafeAadItemRef(ComponentDataRef {
408 component_id,
409 data: aad_item_data,
410 })
411 })
412 }
413}
414
/// Owned AAD item: a newtype around [`ComponentData`].
///
/// With the `serde` feature it (de)serializes transparently as the wrapped value.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct SafeAadItem(ComponentData);
427
428impl SafeAadItem {
429 pub fn as_ref(&self) -> SafeAadItemRef {
430 SafeAadItemRef(self.0.as_ref())
431 }
432}
433
/// Borrowed list of AAD items.
#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct SafeAadRef<'a> {
    /// The items, in the order supplied by the caller; see `is_ordered_and_unique` for the
    /// ordering check.
    pub aad_items: &'a [&'a SafeAadItemRef<'a>],
}
439
440impl SafeAadRef<'_> {
441 pub fn is_ordered_and_unique(&self) -> bool {
442 let mut iter = self.aad_items.iter().peekable();
443
444 while let Some(item) = iter.next() {
445 let Some(next) = iter.peek() else {
446 continue;
447 };
448
449 if item.component_id() >= next.component_id() {
450 return false;
451 }
452 }
453
454 true
455 }
456}
457
458impl<'a> From<&'a [&'a SafeAadItemRef<'a>]> for SafeAadRef<'a> {
459 fn from(aad_items: &'a [&'a SafeAadItemRef<'a>]) -> Self {
460 Self { aad_items }
461 }
462}
463
/// Owned collection of AAD items, stored in a [`ComponentDataMap`] keyed by component id.
#[derive(
    Debug,
    Default,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SafeAad {
    // Private: accessed through the methods of this type.
    aad_items: ComponentDataMap,
}
478
479impl SafeAad {
480 pub fn iter_components(&self) -> impl Iterator<Item = SafeAadItemRef> {
481 self.aad_items
482 .iter()
483 .map(|(component_id, data)| SafeAadItemRef(ComponentDataRef { component_id, data }))
484 }
485
486 pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
487 self.aad_items.extract_component::<C>()
488 }
489
490 pub fn insert_or_update_component<C: Component>(
492 &mut self,
493 component: &C,
494 ) -> crate::MlsSpecResult<bool> {
495 self.aad_items.insert_or_update_component(component)
496 }
497}
498
/// A list of `crate::defs::WireFormat` values.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WireFormats {
    /// The wire formats, in the order given.
    pub wire_formats: Vec<crate::defs::WireFormat>,
}
512
/// A list of component identifiers.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComponentsList {
    /// The component ids, in the order given.
    pub component_ids: Vec<ComponentId>,
}
526
/// Component wrapping a [`ComponentsList`]; its component id is `super::APP_COMPONENTS_ID`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
pub struct AppComponents(pub ComponentsList);

impl Component for AppComponents {
    /// Returns `super::APP_COMPONENTS_ID`.
    fn component_id() -> ComponentId {
        super::APP_COMPONENTS_ID
    }
}
543
/// Component wrapping a [`ComponentsList`]; its component id is `super::SAFE_AAD_ID`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
pub struct SafeAadComponent(pub ComponentsList);

impl Component for SafeAadComponent {
    /// Returns `super::SAFE_AAD_ID`.
    fn component_id() -> ComponentId {
        super::SAFE_AAD_ID
    }
}
560
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{ApplicationDataDictionary, Component, SafeAad, SafeAadItemRef, SafeAadRef};
    use crate::{
        drafts::mls_extensions::last_resort_keypackage::LastResortKeyPackage,
        generate_roundtrip_test,
    };

    // The entries are listed out of order on purpose; the backing `BTreeMap` sorts them.
    generate_roundtrip_test!(can_roundtrip_appdatadict, {
        ApplicationDataDictionary {
            component_data: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    generate_roundtrip_test!(can_roundtrip_safeaad, {
        SafeAad {
            aad_items: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    #[test]
    fn can_build_safe_aad() {
        let cid = LastResortKeyPackage::component_id();

        let mut safe_aad = SafeAad::default();
        let newly_inserted = safe_aad
            .insert_or_update_component(&LastResortKeyPackage)
            .unwrap();
        // The map started empty, so the first insertion must create the entry...
        assert!(newly_inserted);

        let inserted_again = safe_aad
            .insert_or_update_component(&LastResortKeyPackage)
            .unwrap();
        // ...and a second one must only overwrite it.
        assert!(!inserted_again);

        let stored_ids: Vec<_> = safe_aad
            .iter_components()
            .map(|item| *item.component_id())
            .collect();
        assert_eq!(stored_ids, vec![cid]);

        let aad_item_ref =
            SafeAadItemRef::from_item_data::<LastResortKeyPackage>(&cid, &[]).unwrap();

        let items = &[&aad_item_ref];
        let safe_ref = SafeAadRef::from(items.as_slice());
        assert!(safe_ref.is_ordered_and_unique());
    }

    #[test]
    fn duplicate_component_ids_are_not_ordered_and_unique() {
        let cid = LastResortKeyPackage::component_id();
        let aad_item_ref =
            SafeAadItemRef::from_item_data::<LastResortKeyPackage>(&cid, &[]).unwrap();

        // The same id twice violates the strict ordering requirement.
        let items = &[&aad_item_ref, &aad_item_ref];
        let safe_ref = SafeAadRef::from(items.as_slice());
        assert!(!safe_ref.is_ordered_and_unique());
    }
}