mls_spec/drafts/mls_extensions/
safe_application.rs1use std::collections::BTreeMap;
2
3use crate::{SensitiveBytes, key_schedule::PreSharedKeyId};
4
/// Identifier of an application component (`u32`).
pub type ComponentId = u32;

/// GREASE values for [`ComponentId`].
///
/// Every entry has zero upper bytes and repeats the same byte `0xNA` in both low bytes,
/// with the high nibble `N` running from `0x0` up to `0xE` (`0x0A0A`, `0x1A1A`, …, `0xEAEA`).
pub const COMPONENT_ID_GREASE_VALUES: [ComponentId; 15] = [
    0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A, 0x4A4A,
    0x5A5A, 0x6A6A, 0x7A7A, 0x8A8A, 0x9A9A,
    0xAAAA, 0xBABA, 0xCACA, 0xDADA, 0xEAEA,
];
24
25pub trait Component: crate::Parsable + crate::Serializable {
26 fn component_id() -> ComponentId;
27
28 fn psk(psk_id: Vec<u8>, psk_nonce: SensitiveBytes) -> PreSharedKeyId {
29 PreSharedKeyId {
30 psktype: crate::key_schedule::PreSharedKeyIdPskType::Application(
31 crate::key_schedule::ApplicationPsk {
32 component_id: Self::component_id(),
33 psk_id,
34 },
35 ),
36 psk_nonce,
37 }
38 }
39
40 fn to_component_data(&self) -> crate::MlsSpecResult<ComponentData> {
41 Ok(ComponentData {
42 component_id: Self::component_id(),
43 data: self.to_tls_bytes()?,
44 })
45 }
46}
47
48#[derive(
49 Debug,
50 Clone,
51 Copy,
52 Default,
53 PartialEq,
54 Eq,
55 strum::IntoStaticStr,
56 strum::EnumString,
57 strum::Display,
58)]
59#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
60#[repr(u8)]
61pub enum ComponentOperationBaseLabel {
62 #[default]
63 Application = 0x00,
64}
65
66impl tls_codec::Size for ComponentOperationBaseLabel {
67 fn tls_serialized_len(&self) -> usize {
68 crate::tlspl::string::tls_serialized_len(self.into())
69 }
70}
71
72impl tls_codec::Serialize for ComponentOperationBaseLabel {
73 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
74 crate::tlspl::string::tls_serialize(self.into(), writer)
75 }
76}
77
78impl tls_codec::Deserialize for ComponentOperationBaseLabel {
79 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
80 where
81 Self: Sized,
82 {
83 <Self as std::str::FromStr>::from_str(&crate::tlspl::string::tls_deserialize(bytes)?)
84 .map_err(|_| {
85 tls_codec::Error::DecodingError(
86 "Unknown Value in ComponentOperationBaseLabel".into(),
87 )
88 })
89 }
90}
91
/// Label combining a base label, the owning component's id and an operation-specific label.
///
/// The TLS derives encode the fields in declaration order, so the field order below is part
/// of the wire format and must not be rearranged.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComponentOperationLabel {
    /// Base label, encoded as a TLS string of the variant name.
    pub base_label: ComponentOperationBaseLabel,
    /// Component the operation belongs to.
    pub component_id: ComponentId,
    /// Operation-specific label bytes, encoded via `crate::tlspl::bytes`.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub label: Vec<u8>,
}
108
109#[derive(
110 Debug,
111 Clone,
112 PartialEq,
113 Eq,
114 Hash,
115 tls_codec::TlsSerialize,
116 tls_codec::TlsDeserialize,
117 tls_codec::TlsSize,
118)]
119#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
120pub struct ComponentData {
121 pub component_id: ComponentId,
122 #[tls_codec(with = "crate::tlspl::bytes")]
123 pub data: Vec<u8>,
124}
125
126impl ComponentData {
127 pub fn as_ref(&self) -> ComponentDataRef<'_> {
128 ComponentDataRef {
129 component_id: &self.component_id,
130 data: &self.data,
131 }
132 }
133}
134
/// Borrowed counterpart of [`ComponentData`].
///
/// It has the same fields in the same order and uses the same `crate::tlspl::bytes` payload
/// encoding. It only supports serialization (there is no `TlsDeserialize` / serde
/// `Deserialize` derive).
#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ComponentDataRef<'a> {
    /// Identifier of the component that produced `data`.
    pub component_id: &'a ComponentId,
    /// The component's serialized payload.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub data: &'a [u8],
}
142
143#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
148#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
149#[cfg_attr(
150 feature = "serde",
151 serde(from = "Vec<ComponentData>", into = "Vec<ComponentData>")
152)]
153pub struct ComponentDataMap(BTreeMap<ComponentId, Vec<u8>>);
154
155impl ComponentDataMap {
156 fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
157 self.0
158 .get(&C::component_id())
159 .map(|data| C::from_tls_bytes(data))
160 .transpose()
161 }
162
163 fn insert_or_update_component<C: Component>(
164 &mut self,
165 component: &C,
166 ) -> crate::MlsSpecResult<bool> {
167 let component_data = component.to_tls_bytes()?;
169 match self.0.entry(C::component_id()) {
170 std::collections::btree_map::Entry::Vacant(vacant_entry) => {
171 vacant_entry.insert(component_data);
172 Ok(true)
173 }
174 std::collections::btree_map::Entry::Occupied(mut occupied_entry) => {
175 *(occupied_entry.get_mut()) = component_data;
176 Ok(false)
177 }
178 }
179 }
180
181 fn iter(&self) -> impl Iterator<Item = (&ComponentId, &[u8])> {
182 self.0.iter().map(|(cid, data)| (cid, data.as_slice()))
183 }
184}
185
186impl tls_codec::Size for ComponentDataMap {
187 fn tls_serialized_len(&self) -> usize {
188 crate::tlspl::tls_serialized_len_as_vlvec(
189 self.iter()
190 .map(|(component_id, data)| {
191 ComponentDataRef { component_id, data }.tls_serialized_len()
192 })
193 .sum(),
194 )
195 }
196}
197
198impl tls_codec::Deserialize for ComponentDataMap {
199 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
200 where
201 Self: Sized,
202 {
203 let tlspl_value: Vec<ComponentData> = <_>::tls_deserialize(bytes)?;
204
205 Ok(Self(BTreeMap::from_iter(
206 tlspl_value
207 .into_iter()
208 .map(|cdata| (cdata.component_id, cdata.data)),
209 )))
210 }
211}
212
213impl tls_codec::Serialize for ComponentDataMap {
214 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
215 self.iter()
217 .map(|(component_id, data)| ComponentDataRef { component_id, data })
218 .collect::<Vec<ComponentDataRef>>()
219 .tls_serialize(writer)
220 }
221}
222
// Read-only access to the backing `BTreeMap` (e.g. `get`, `contains_key`, `len`, `remove`
// via `DerefMut`). Inside this module the private inherent `ComponentDataMap::iter`
// takes precedence over `BTreeMap::iter`.
impl std::ops::Deref for ComponentDataMap {
    type Target = BTreeMap<ComponentId, Vec<u8>>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// Mutable access to the backing `BTreeMap`. This bypasses the typed `Component` helpers,
// so callers can store arbitrary bytes under any id.
impl std::ops::DerefMut for ComponentDataMap {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
235
236impl From<Vec<ComponentData>> for ComponentDataMap {
237 fn from(value: Vec<ComponentData>) -> Self {
238 Self(BTreeMap::from_iter(
239 value
240 .into_iter()
241 .map(|component| (component.component_id, component.data)),
242 ))
243 }
244}
245
246#[allow(clippy::from_over_into)]
247impl Into<Vec<ComponentData>> for ComponentDataMap {
248 fn into(self) -> Vec<ComponentData> {
249 self.0
250 .into_iter()
251 .map(|(component_id, data)| ComponentData { component_id, data })
252 .collect()
253 }
254}
255
256#[derive(
261 Debug,
262 Default,
263 Clone,
264 PartialEq,
265 Eq,
266 Hash,
267 tls_codec::TlsSize,
268 tls_codec::TlsDeserialize,
269 tls_codec::TlsSerialize,
270)]
271#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
272pub struct ApplicationDataDictionary {
273 pub component_data: ComponentDataMap,
274}
275
276impl ApplicationDataDictionary {
277 pub fn iter_components(&self) -> impl Iterator<Item = ComponentDataRef<'_>> {
278 self.component_data
279 .iter()
280 .map(|(component_id, data)| ComponentDataRef { component_id, data })
281 }
282
283 pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
284 self.component_data.extract_component::<C>()
285 }
286
287 pub fn insert_or_update_component<C: Component>(
289 &mut self,
290 component: &C,
291 ) -> crate::MlsSpecResult<bool> {
292 self.component_data.insert_or_update_component(component)
293 }
294
295 pub fn apply_update(&mut self, update: AppDataUpdate) -> bool {
300 match update.op {
301 ApplicationDataUpdateOperation::Update { update: data } => {
302 *self.component_data.entry(update.component_id).or_default() = data;
303 true
304 }
305 ApplicationDataUpdateOperation::Remove => {
306 self.component_data.remove(&update.component_id).is_some()
307 }
308 }
309 }
310}
311
312impl From<ApplicationDataDictionary> for crate::group::extensions::Extension {
313 fn from(val: ApplicationDataDictionary) -> Self {
314 crate::group::extensions::Extension::ApplicationData(val)
315 }
316}
317
/// `u8` discriminants used on the wire for [`ApplicationDataUpdateOperation`].
///
/// `Invalid` (`0x00`) is not the discriminant of any `ApplicationDataUpdateOperation`
/// variant. With the `serde` feature the enum is (de)serialized as its numeric value.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[repr(u8)]
#[cfg_attr(
    feature = "serde",
    derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
)]
pub enum ApplicationDataUpdateOperationType {
    Invalid = 0x00,
    Update = 0x01,
    Remove = 0x02,
}
337
/// Operation carried by an [`AppDataUpdate`].
///
/// The TLS encoding uses the discriminant values declared in
/// [`ApplicationDataUpdateOperationType`].
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ApplicationDataUpdateOperation {
    /// Replace (or create) the component's payload with `update`.
    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Update")]
    Update {
        #[tls_codec(with = "crate::tlspl::bytes")]
        update: Vec<u8>,
    },
    /// Remove the component's entry.
    #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Remove")]
    Remove,
}
358
359#[derive(
360 Debug,
361 Clone,
362 PartialEq,
363 Eq,
364 tls_codec::TlsSerialize,
365 tls_codec::TlsDeserialize,
366 tls_codec::TlsSize,
367)]
368#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
369pub struct AppDataUpdate {
370 pub component_id: ComponentId,
371 pub op: ApplicationDataUpdateOperation,
372}
373
374impl AppDataUpdate {
375 pub fn extract_component_update<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
380 let type_component_id = C::component_id();
381 if type_component_id != self.component_id {
382 return Err(crate::MlsSpecError::SafeAppComponentIdMismatch {
383 expected: type_component_id,
384 actual: self.component_id,
385 });
386 }
387
388 let ApplicationDataUpdateOperation::Update { update } = &self.op else {
389 return Ok(None);
390 };
391
392 Ok(Some(C::from_tls_bytes(update)?))
393 }
394}
395
/// A component id paired with opaque application bytes.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ApplicationData {
    /// Component that owns `data`.
    pub component_id: ComponentId,
    /// Opaque payload, encoded via `crate::tlspl::bytes`.
    #[tls_codec(with = "crate::tlspl::bytes")]
    pub data: Vec<u8>,
}

/// Ephemeral application data shares the [`ApplicationData`] layout.
pub type AppEphemeral = ApplicationData;
413
414#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
415#[cfg_attr(feature = "serde", derive(serde::Serialize))]
416#[cfg_attr(feature = "serde", serde(transparent))]
417pub struct SafeAadItemRef<'a>(ComponentDataRef<'a>);
418
419impl<'a> SafeAadItemRef<'a> {
420 pub fn component_id(&self) -> &ComponentId {
421 self.0.component_id
422 }
423
424 pub fn aad_item_data(&self) -> &[u8] {
425 self.0.data
426 }
427
428 pub fn from_item_data<C: Component>(
429 component_id: &'a ComponentId,
430 aad_item_data: &'a [u8],
431 ) -> Option<Self> {
432 (&C::component_id() == component_id).then_some(SafeAadItemRef(ComponentDataRef {
433 component_id,
434 data: aad_item_data,
435 }))
436 }
437}
438
439#[derive(
440 Debug,
441 Clone,
442 PartialEq,
443 Eq,
444 Hash,
445 tls_codec::TlsSerialize,
446 tls_codec::TlsDeserialize,
447 tls_codec::TlsSize,
448)]
449#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
450#[cfg_attr(feature = "serde", serde(transparent))]
451pub struct SafeAadItem(ComponentData);
452
453impl SafeAadItem {
454 pub fn as_ref(&self) -> SafeAadItemRef<'_> {
455 SafeAadItemRef(self.0.as_ref())
456 }
457}
458
459#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
460#[cfg_attr(feature = "serde", derive(serde::Serialize))]
461pub struct SafeAadRef<'a> {
462 pub aad_items: &'a [&'a SafeAadItemRef<'a>],
463}
464
465impl SafeAadRef<'_> {
466 pub fn is_ordered_and_unique(&self) -> bool {
467 let mut iter = self.aad_items.iter().peekable();
468
469 while let Some(item) = iter.next() {
470 let Some(next) = iter.peek() else {
471 continue;
472 };
473
474 if item.component_id() >= next.component_id() {
475 return false;
476 }
477 }
478
479 true
480 }
481}
482
483impl<'a> From<&'a [&'a SafeAadItemRef<'a>]> for SafeAadRef<'a> {
484 fn from(aad_items: &'a [&'a SafeAadItemRef<'a>]) -> Self {
485 Self { aad_items }
486 }
487}
488
489#[derive(
490 Debug,
491 Default,
492 Clone,
493 PartialEq,
494 Eq,
495 Hash,
496 tls_codec::TlsSerialize,
497 tls_codec::TlsDeserialize,
498 tls_codec::TlsSize,
499)]
500#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
501pub struct SafeAad {
502 aad_items: ComponentDataMap,
503}
504
505impl SafeAad {
506 pub fn iter_components(&self) -> impl Iterator<Item = SafeAadItemRef<'_>> {
507 self.aad_items
508 .iter()
509 .map(|(component_id, data)| SafeAadItemRef(ComponentDataRef { component_id, data }))
510 }
511
512 pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
513 self.aad_items.extract_component::<C>()
514 }
515
516 pub fn insert_or_update_component<C: Component>(
518 &mut self,
519 component: &C,
520 ) -> crate::MlsSpecResult<bool> {
521 self.aad_items.insert_or_update_component(component)
522 }
523}
524
/// A list of [`crate::defs::WireFormat`] values.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WireFormats {
    pub wire_formats: Vec<crate::defs::WireFormat>,
}
539
/// A list of [`ComponentId`]s.
///
/// This is the payload of both [`AppComponents`] and [`SafeAadComponent`].
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComponentsList {
    pub component_ids: Vec<ComponentId>,
}
553
/// Component whose payload is a [`ComponentsList`], stored under `super::APP_COMPONENTS_ID`.
///
/// NOTE(review): presumably lists the application components in use; confirm against the
/// MLS extensions draft.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
pub struct AppComponents(pub ComponentsList);

impl Component for AppComponents {
    fn component_id() -> ComponentId {
        super::APP_COMPONENTS_ID
    }
}
570
/// Component whose payload is a [`ComponentsList`], stored under `super::SAFE_AAD_ID`.
///
/// NOTE(review): presumably lists the components that contribute SafeAAD items; confirm
/// against the MLS extensions draft.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
pub struct SafeAadComponent(pub ComponentsList);

impl Component for SafeAadComponent {
    fn component_id() -> ComponentId {
        super::SAFE_AAD_ID
    }
}
587
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{ApplicationDataDictionary, Component, SafeAad, SafeAadItemRef, SafeAadRef};
    use crate::{
        drafts::mls_extensions::last_resort_keypackage::LastResortKeyPackage,
        generate_roundtrip_test,
    };

    // `generate_roundtrip_test!` is defined elsewhere in the crate; presumably it serializes
    // the value, deserializes it back and compares (TODO confirm). Entries are listed out of
    // order on purpose: the `BTreeMap` sorts them, so the encoding is in ascending id order.
    generate_roundtrip_test!(can_roundtrip_appdatadict, {
        ApplicationDataDictionary {
            component_data: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    // Same round-trip check for `SafeAad`, which shares the `ComponentDataMap` encoding.
    generate_roundtrip_test!(can_roundtrip_safeaad, {
        SafeAad {
            aad_items: super::ComponentDataMap(BTreeMap::from([
                (1, vec![1]),
                (3, vec![3]),
                (2, vec![2]),
            ])),
        }
    });

    // Inserting a component into `SafeAad` must succeed, and a single-item `SafeAadRef`
    // is trivially ordered and unique.
    #[test]
    fn can_build_safe_aad() {
        let mut safe_aad = SafeAad::default();
        safe_aad
            .insert_or_update_component(&LastResortKeyPackage)
            .unwrap();

        let cid = LastResortKeyPackage::component_id();
        // Empty item data; presumably matches LastResortKeyPackage's (empty) encoding —
        // TODO confirm.
        let aad_item_ref =
            SafeAadItemRef::from_item_data::<LastResortKeyPackage>(&cid, &[]).unwrap();

        let items = &[&aad_item_ref];
        let safe_ref = SafeAadRef::from(items.as_slice());
        assert!(safe_ref.is_ordered_and_unique());
    }
}
633}