1use std::collections::BTreeMap;
2
3use crate::{
4 SensitiveBytes, ToPrefixedLabel,
5 defs::{
6 ProtocolVersion,
7 labels::{PublicKeyEncryptionLabel, SignatureLabel},
8 },
9 key_schedule::PreSharedKeyId,
10};
11
12pub type ComponentId = u32;
13
14pub trait Component: crate::Parsable + crate::Serializable {
15 fn component_id() -> ComponentId;
16
17 fn psk(psk_id: Vec<u8>, psk_nonce: SensitiveBytes) -> PreSharedKeyId {
18 PreSharedKeyId {
19 psktype: crate::key_schedule::PreSharedKeyIdPskType::Application(
20 crate::key_schedule::ApplicationPsk {
21 component_id: Self::component_id(),
22 psk_id,
23 },
24 ),
25 psk_nonce,
26 }
27 }
28
29 fn to_component_data(&self) -> crate::MlsSpecResult<ComponentData> {
30 Ok(ComponentData {
31 component_id: Self::component_id(),
32 data: self.to_tls_bytes()?,
33 })
34 }
35}
36
37#[derive(
38 Debug,
39 Clone,
40 PartialEq,
41 Eq,
42 tls_codec::TlsSerialize,
43 tls_codec::TlsDeserialize,
44 tls_codec::TlsSize,
45)]
46#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
47pub struct ComponentOperationLabel {
48 #[tls_codec(with = "crate::tlspl::bytes")]
49 label: Vec<u8>,
50 pub component_id: ComponentId,
51 #[tls_codec(with = "crate::tlspl::bytes")]
52 pub context: Vec<u8>,
53}
54
55impl ComponentOperationLabel {
56 pub fn new_with_context_for_hpke(component_id: ComponentId, context: Vec<u8>) -> Self {
58 Self {
59 label: PublicKeyEncryptionLabel::SafeApp
60 .to_prefixed_string(ProtocolVersion::default())
61 .into_bytes(),
62 component_id,
63 context,
64 }
65 }
66
67 pub fn new_for_signature(
69 label: SignatureLabel,
70 component_id: ComponentId,
71 context: Vec<u8>,
72 ) -> (Self, SignatureLabel) {
73 let col = Self {
74 label: label
75 .to_prefixed_string(ProtocolVersion::default())
76 .into_bytes(),
77 component_id,
78 context,
79 };
80
81 (col, SignatureLabel::ComponentOperationLabel)
82 }
83}
84
85#[derive(
86 Debug,
87 Clone,
88 PartialEq,
89 Eq,
90 Hash,
91 tls_codec::TlsSerialize,
92 tls_codec::TlsDeserialize,
93 tls_codec::TlsSize,
94)]
95#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
96pub struct ComponentData {
97 pub component_id: ComponentId,
98 #[tls_codec(with = "crate::tlspl::bytes")]
99 pub data: Vec<u8>,
100}
101
102impl ComponentData {
103 pub fn as_ref(&self) -> ComponentDataRef<'_> {
104 ComponentDataRef {
105 component_id: &self.component_id,
106 data: &self.data,
107 }
108 }
109}
110
111#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
112#[cfg_attr(feature = "serde", derive(serde::Serialize))]
113pub struct ComponentDataRef<'a> {
114 pub component_id: &'a ComponentId,
115 #[tls_codec(with = "crate::tlspl::bytes")]
116 pub data: &'a [u8],
117}
118
119#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
124#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
125#[cfg_attr(
126 feature = "serde",
127 serde(from = "Vec<ComponentData>", into = "Vec<ComponentData>")
128)]
129pub struct ComponentDataMap(BTreeMap<ComponentId, Vec<u8>>);
130
131impl ComponentDataMap {
132 fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
133 self.0
134 .get(&C::component_id())
135 .map(|data| C::from_tls_bytes(data))
136 .transpose()
137 }
138
139 fn insert_or_update_component<C: Component>(
140 &mut self,
141 component: &C,
142 ) -> crate::MlsSpecResult<bool> {
143 let component_data = component.to_tls_bytes()?;
145 match self.0.entry(C::component_id()) {
146 std::collections::btree_map::Entry::Vacant(vacant_entry) => {
147 vacant_entry.insert(component_data);
148 Ok(true)
149 }
150 std::collections::btree_map::Entry::Occupied(mut occupied_entry) => {
151 *(occupied_entry.get_mut()) = component_data;
152 Ok(false)
153 }
154 }
155 }
156
157 fn iter(&self) -> impl Iterator<Item = (&ComponentId, &[u8])> {
158 self.0.iter().map(|(cid, data)| (cid, data.as_slice()))
159 }
160}
161
162impl tls_codec::Size for ComponentDataMap {
163 fn tls_serialized_len(&self) -> usize {
164 crate::tlspl::tls_serialized_len_as_vlvec(
165 self.iter()
166 .map(|(component_id, data)| {
167 ComponentDataRef { component_id, data }.tls_serialized_len()
168 })
169 .sum(),
170 )
171 }
172}
173
174impl tls_codec::Deserialize for ComponentDataMap {
175 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
176 where
177 Self: Sized,
178 {
179 let tlspl_value: Vec<ComponentData> = <_>::tls_deserialize(bytes)?;
180
181 Ok(Self(BTreeMap::from_iter(
182 tlspl_value
183 .into_iter()
184 .map(|cdata| (cdata.component_id, cdata.data)),
185 )))
186 }
187}
188
189impl tls_codec::Serialize for ComponentDataMap {
190 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
191 self.iter()
193 .map(|(component_id, data)| ComponentDataRef { component_id, data })
194 .collect::<Vec<ComponentDataRef>>()
195 .tls_serialize(writer)
196 }
197}
198
199impl std::ops::Deref for ComponentDataMap {
200 type Target = BTreeMap<ComponentId, Vec<u8>>;
201 fn deref(&self) -> &Self::Target {
202 &self.0
203 }
204}
205
206impl std::ops::DerefMut for ComponentDataMap {
207 fn deref_mut(&mut self) -> &mut Self::Target {
208 &mut self.0
209 }
210}
211
212impl From<Vec<ComponentData>> for ComponentDataMap {
213 fn from(value: Vec<ComponentData>) -> Self {
214 Self(BTreeMap::from_iter(
215 value
216 .into_iter()
217 .map(|component| (component.component_id, component.data)),
218 ))
219 }
220}
221
222#[allow(clippy::from_over_into)]
223impl Into<Vec<ComponentData>> for ComponentDataMap {
224 fn into(self) -> Vec<ComponentData> {
225 self.0
226 .into_iter()
227 .map(|(component_id, data)| ComponentData { component_id, data })
228 .collect()
229 }
230}
231
232#[derive(
237 Debug,
238 Default,
239 Clone,
240 PartialEq,
241 Eq,
242 Hash,
243 tls_codec::TlsSize,
244 tls_codec::TlsDeserialize,
245 tls_codec::TlsSerialize,
246)]
247#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
248pub struct ApplicationDataDictionary {
249 pub component_data: ComponentDataMap,
250}
251
252impl ApplicationDataDictionary {
253 pub fn iter_components(&self) -> impl Iterator<Item = ComponentDataRef<'_>> {
254 self.component_data
255 .iter()
256 .map(|(component_id, data)| ComponentDataRef { component_id, data })
257 }
258
259 pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
260 self.component_data.extract_component::<C>()
261 }
262
263 pub fn insert_or_update_component<C: Component>(
265 &mut self,
266 component: &C,
267 ) -> crate::MlsSpecResult<bool> {
268 self.component_data.insert_or_update_component(component)
269 }
270
271 pub fn apply_update(&mut self, update: AppDataUpdate) -> bool {
276 match update.op {
277 ApplicationDataUpdateOperation::Update { update: data } => {
278 *self.component_data.entry(update.component_id).or_default() = data;
279 true
280 }
281 ApplicationDataUpdateOperation::Remove => {
282 self.component_data.remove(&update.component_id).is_some()
283 }
284 }
285 }
286}
287
288impl From<ApplicationDataDictionary> for crate::group::extensions::Extension {
289 fn from(val: ApplicationDataDictionary) -> Self {
290 crate::group::extensions::Extension::ApplicationData(val)
291 }
292}
293
294#[derive(
295 Debug,
296 Clone,
297 PartialEq,
298 Eq,
299 tls_codec::TlsSerialize,
300 tls_codec::TlsDeserialize,
301 tls_codec::TlsSize,
302)]
303#[repr(u8)]
304#[cfg_attr(
305 feature = "serde",
306 derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
307)]
308pub enum ApplicationDataUpdateOperationType {
309 Invalid = 0x00,
310 Update = 0x01,
311 Remove = 0x02,
312}
313
314#[derive(
315 Debug,
316 Clone,
317 PartialEq,
318 Eq,
319 tls_codec::TlsSerialize,
320 tls_codec::TlsDeserialize,
321 tls_codec::TlsSize,
322)]
323#[repr(u8)]
324#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
325pub enum ApplicationDataUpdateOperation {
326 #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Update")]
327 Update {
328 #[tls_codec(with = "crate::tlspl::bytes")]
329 update: Vec<u8>,
330 },
331 #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Remove")]
332 Remove,
333}
334
335#[derive(
336 Debug,
337 Clone,
338 PartialEq,
339 Eq,
340 tls_codec::TlsSerialize,
341 tls_codec::TlsDeserialize,
342 tls_codec::TlsSize,
343)]
344#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
345pub struct AppDataUpdate {
346 pub component_id: ComponentId,
347 pub op: ApplicationDataUpdateOperation,
348}
349
350impl AppDataUpdate {
351 pub fn extract_component_update<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
356 let type_component_id = C::component_id();
357 if type_component_id != self.component_id {
358 return Err(crate::MlsSpecError::SafeAppComponentIdMismatch {
359 expected: type_component_id,
360 actual: self.component_id,
361 });
362 }
363
364 let ApplicationDataUpdateOperation::Update { update } = &self.op else {
365 return Ok(None);
366 };
367
368 Ok(Some(C::from_tls_bytes(update)?))
369 }
370}
371
372#[derive(
373 Debug,
374 Clone,
375 PartialEq,
376 Eq,
377 tls_codec::TlsSerialize,
378 tls_codec::TlsDeserialize,
379 tls_codec::TlsSize,
380)]
381#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
382pub struct ApplicationData {
383 pub component_id: ComponentId,
384 #[tls_codec(with = "crate::tlspl::bytes")]
385 pub data: Vec<u8>,
386}
387
388pub type AppEphemeral = ApplicationData;
389
390#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
391#[cfg_attr(feature = "serde", derive(serde::Serialize))]
392#[cfg_attr(feature = "serde", serde(transparent))]
393pub struct SafeAadItemRef<'a>(ComponentDataRef<'a>);
394
395impl<'a> SafeAadItemRef<'a> {
396 pub fn component_id(&self) -> &ComponentId {
397 self.0.component_id
398 }
399
400 pub fn aad_item_data(&self) -> &[u8] {
401 self.0.data
402 }
403
404 pub fn from_item_data<C: Component>(
405 component_id: &'a ComponentId,
406 aad_item_data: &'a [u8],
407 ) -> Option<Self> {
408 (&C::component_id() == component_id).then_some(SafeAadItemRef(ComponentDataRef {
409 component_id,
410 data: aad_item_data,
411 }))
412 }
413}
414
415#[derive(
416 Debug,
417 Clone,
418 PartialEq,
419 Eq,
420 Hash,
421 tls_codec::TlsSerialize,
422 tls_codec::TlsDeserialize,
423 tls_codec::TlsSize,
424)]
425#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
426#[cfg_attr(feature = "serde", serde(transparent))]
427pub struct SafeAadItem(ComponentData);
428
429impl SafeAadItem {
430 pub fn as_ref(&self) -> SafeAadItemRef<'_> {
431 SafeAadItemRef(self.0.as_ref())
432 }
433}
434
435#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
436#[cfg_attr(feature = "serde", derive(serde::Serialize))]
437pub struct SafeAadRef<'a> {
438 pub aad_items: &'a [&'a SafeAadItemRef<'a>],
439}
440
441impl SafeAadRef<'_> {
442 pub fn is_ordered_and_unique(&self) -> bool {
443 let mut iter = self.aad_items.iter().peekable();
444
445 while let Some(item) = iter.next() {
446 let Some(next) = iter.peek() else {
447 continue;
448 };
449
450 if item.component_id() >= next.component_id() {
451 return false;
452 }
453 }
454
455 true
456 }
457}
458
459impl<'a> From<&'a [&'a SafeAadItemRef<'a>]> for SafeAadRef<'a> {
460 fn from(aad_items: &'a [&'a SafeAadItemRef<'a>]) -> Self {
461 Self { aad_items }
462 }
463}
464
465#[derive(
466 Debug,
467 Default,
468 Clone,
469 PartialEq,
470 Eq,
471 Hash,
472 tls_codec::TlsSerialize,
473 tls_codec::TlsDeserialize,
474 tls_codec::TlsSize,
475)]
476#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
477pub struct SafeAad {
478 aad_items: ComponentDataMap,
479}
480
481impl SafeAad {
482 pub fn iter_components(&self) -> impl Iterator<Item = SafeAadItemRef<'_>> {
483 self.aad_items
484 .iter()
485 .map(|(component_id, data)| SafeAadItemRef(ComponentDataRef { component_id, data }))
486 }
487
488 pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
489 self.aad_items.extract_component::<C>()
490 }
491
492 pub fn insert_or_update_component<C: Component>(
494 &mut self,
495 component: &C,
496 ) -> crate::MlsSpecResult<bool> {
497 self.aad_items.insert_or_update_component(component)
498 }
499}
500
501#[derive(
502 Debug,
503 Clone,
504 PartialEq,
505 Eq,
506 Hash,
507 tls_codec::TlsSerialize,
508 tls_codec::TlsDeserialize,
509 tls_codec::TlsSize,
510)]
511#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
512pub struct WireFormats {
513 pub wire_formats: Vec<crate::defs::WireFormat>,
514}
515
516#[derive(
517 Debug,
518 Clone,
519 PartialEq,
520 Eq,
521 tls_codec::TlsSerialize,
522 tls_codec::TlsDeserialize,
523 tls_codec::TlsSize,
524)]
525#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
526pub struct ComponentsList {
527 pub component_ids: Vec<ComponentId>,
528}
529
530#[derive(
531 Debug,
532 Clone,
533 PartialEq,
534 Eq,
535 tls_codec::TlsSerialize,
536 tls_codec::TlsDeserialize,
537 tls_codec::TlsSize,
538)]
539pub struct AppComponents(pub ComponentsList);
540
541impl Component for AppComponents {
542 fn component_id() -> ComponentId {
543 super::APP_COMPONENTS_ID
544 }
545}
546
547#[derive(
548 Debug,
549 Clone,
550 PartialEq,
551 Eq,
552 tls_codec::TlsSerialize,
553 tls_codec::TlsDeserialize,
554 tls_codec::TlsSize,
555)]
556pub struct SafeAadComponent(pub ComponentsList);
557
558impl Component for SafeAadComponent {
559 fn component_id() -> ComponentId {
560 super::SAFE_AAD_ID
561 }
562}
563
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{ApplicationDataDictionary, Component, SafeAad, SafeAadItemRef, SafeAadRef};
    use crate::{
        drafts::mls_extensions::last_resort_keypackage::LastResortKeyPackage,
        generate_roundtrip_test,
    };

    // Entries are listed out of order on purpose; the map keeps them sorted by ID.
    generate_roundtrip_test!(can_roundtrip_appdatadict, {
        ApplicationDataDictionary {
            component_data: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    generate_roundtrip_test!(can_roundtrip_safeaad, {
        SafeAad {
            aad_items: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    /// Inserting a component and referencing it from a single-item `SafeAadRef`
    /// yields a list that is trivially ordered and unique.
    #[test]
    fn can_build_safe_aad() {
        let component_id = LastResortKeyPackage::component_id();

        let mut safe_aad = SafeAad::default();
        safe_aad
            .insert_or_update_component(&LastResortKeyPackage)
            .unwrap();

        let item =
            SafeAadItemRef::from_item_data::<LastResortKeyPackage>(&component_id, &[]).unwrap();
        let items = [&item];
        let safe_ref = SafeAadRef::from(&items[..]);

        assert!(safe_ref.is_ordered_and_unique());
    }
}