mls_spec/drafts/mls_extensions/
safe_application.rs1use std::collections::BTreeMap;
2
3use crate::{SensitiveBytes, key_schedule::PreSharedKeyId};
4
5pub type ComponentId = u16;
6
7pub const COMPONENT_ID_GREASE_VALUES: [ComponentId; 8] = [
8 0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A, 0x4A4A, 0x5A5A, 0x6A6A, 0x7A7A,
9];
10
11pub trait Component: crate::Parsable + crate::Serializable {
12 fn component_id() -> ComponentId;
13
14 fn psk(psk_id: Vec<u8>, psk_nonce: SensitiveBytes) -> PreSharedKeyId {
15 PreSharedKeyId {
16 psktype: crate::key_schedule::PreSharedKeyIdPskType::Application(
17 crate::key_schedule::ApplicationPsk {
18 component_id: Self::component_id(),
19 psk_id,
20 },
21 ),
22 psk_nonce,
23 }
24 }
25
26 fn to_component_data(&self) -> crate::MlsSpecResult<ComponentData> {
27 Ok(ComponentData {
28 component_id: Self::component_id(),
29 data: self.to_tls_bytes()?,
30 })
31 }
32}
33
34#[derive(
35 Debug, Clone, Default, PartialEq, Eq, strum::IntoStaticStr, strum::EnumString, strum::Display,
36)]
37#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
38#[repr(u8)]
39pub enum ComponentOperationBaseLabel {
40 #[default]
41 #[strum(serialize = "MLS Component")]
42 MlsComponent,
43 Custom(String),
45}
46
47impl tls_codec::Size for ComponentOperationBaseLabel {
48 fn tls_serialized_len(&self) -> usize {
49 crate::tlspl::string::tls_serialized_len(self.into())
50 }
51}
52
53impl tls_codec::Serialize for ComponentOperationBaseLabel {
54 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
55 crate::tlspl::string::tls_serialize(self.into(), writer)
56 }
57}
58
59impl tls_codec::Deserialize for ComponentOperationBaseLabel {
60 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
61 where
62 Self: Sized,
63 {
64 let raw_str = crate::tlspl::string::tls_deserialize(bytes)?;
65 Ok(Self::try_from(raw_str.as_str()).unwrap_or(Self::Custom(raw_str)))
66 }
67}
68
69#[derive(
70 Debug,
71 Clone,
72 PartialEq,
73 Eq,
74 tls_codec::TlsSerialize,
75 tls_codec::TlsDeserialize,
76 tls_codec::TlsSize,
77)]
78#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
79pub struct ComponentOperationLabel {
80 pub base_label: ComponentOperationBaseLabel,
81 pub component_id: ComponentId,
82 #[tls_codec(with = "crate::tlspl::bytes")]
83 pub label: Vec<u8>,
84}
85
86#[derive(
87 Debug,
88 Clone,
89 PartialEq,
90 Eq,
91 Hash,
92 tls_codec::TlsSerialize,
93 tls_codec::TlsDeserialize,
94 tls_codec::TlsSize,
95)]
96#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
97pub struct ComponentData {
98 pub component_id: ComponentId,
99 #[tls_codec(with = "crate::tlspl::bytes")]
100 pub data: Vec<u8>,
101}
102
103impl ComponentData {
104 pub fn as_ref(&self) -> ComponentDataRef<'_> {
105 ComponentDataRef {
106 component_id: &self.component_id,
107 data: &self.data,
108 }
109 }
110}
111
112#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
113#[cfg_attr(feature = "serde", derive(serde::Serialize))]
114pub struct ComponentDataRef<'a> {
115 pub component_id: &'a ComponentId,
116 #[tls_codec(with = "crate::tlspl::bytes")]
117 pub data: &'a [u8],
118}
119
120#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
125#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
126#[cfg_attr(
127 feature = "serde",
128 serde(from = "Vec<ComponentData>", into = "Vec<ComponentData>")
129)]
130pub struct ComponentDataMap(BTreeMap<ComponentId, Vec<u8>>);
131
132impl ComponentDataMap {
133 fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
134 self.0
135 .get(&C::component_id())
136 .map(|data| C::from_tls_bytes(data))
137 .transpose()
138 }
139
140 fn insert_or_update_component<C: Component>(
141 &mut self,
142 component: &C,
143 ) -> crate::MlsSpecResult<bool> {
144 let component_data = component.to_tls_bytes()?;
146 match self.0.entry(C::component_id()) {
147 std::collections::btree_map::Entry::Vacant(vacant_entry) => {
148 vacant_entry.insert(component_data);
149 Ok(true)
150 }
151 std::collections::btree_map::Entry::Occupied(mut occupied_entry) => {
152 *(occupied_entry.get_mut()) = component_data;
153 Ok(false)
154 }
155 }
156 }
157
158 fn iter(&self) -> impl Iterator<Item = (&ComponentId, &[u8])> {
159 self.0.iter().map(|(cid, data)| (cid, data.as_slice()))
160 }
161}
162
163impl tls_codec::Size for ComponentDataMap {
164 fn tls_serialized_len(&self) -> usize {
165 crate::tlspl::tls_serialized_len_as_vlvec(
166 self.iter()
167 .map(|(component_id, data)| {
168 ComponentDataRef { component_id, data }.tls_serialized_len()
169 })
170 .sum(),
171 )
172 }
173}
174
175impl tls_codec::Deserialize for ComponentDataMap {
176 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
177 where
178 Self: Sized,
179 {
180 let tlspl_value: Vec<ComponentData> = <_>::tls_deserialize(bytes)?;
181
182 Ok(Self(BTreeMap::from_iter(
183 tlspl_value
184 .into_iter()
185 .map(|cdata| (cdata.component_id, cdata.data)),
186 )))
187 }
188}
189
190impl tls_codec::Serialize for ComponentDataMap {
191 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
192 self.iter()
194 .map(|(component_id, data)| ComponentDataRef { component_id, data })
195 .collect::<Vec<ComponentDataRef>>()
196 .tls_serialize(writer)
197 }
198}
199
200impl std::ops::Deref for ComponentDataMap {
201 type Target = BTreeMap<ComponentId, Vec<u8>>;
202 fn deref(&self) -> &Self::Target {
203 &self.0
204 }
205}
206
207impl std::ops::DerefMut for ComponentDataMap {
208 fn deref_mut(&mut self) -> &mut Self::Target {
209 &mut self.0
210 }
211}
212
213impl From<Vec<ComponentData>> for ComponentDataMap {
214 fn from(value: Vec<ComponentData>) -> Self {
215 Self(BTreeMap::from_iter(
216 value
217 .into_iter()
218 .map(|component| (component.component_id, component.data)),
219 ))
220 }
221}
222
223#[allow(clippy::from_over_into)]
224impl Into<Vec<ComponentData>> for ComponentDataMap {
225 fn into(self) -> Vec<ComponentData> {
226 self.0
227 .into_iter()
228 .map(|(component_id, data)| ComponentData { component_id, data })
229 .collect()
230 }
231}
232
233#[derive(
238 Debug,
239 Default,
240 Clone,
241 PartialEq,
242 Eq,
243 Hash,
244 tls_codec::TlsSize,
245 tls_codec::TlsDeserialize,
246 tls_codec::TlsSerialize,
247)]
248#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
249pub struct ApplicationDataDictionary {
250 pub component_data: ComponentDataMap,
251}
252
253impl ApplicationDataDictionary {
254 pub fn iter_components(&self) -> impl Iterator<Item = ComponentDataRef<'_>> {
255 self.component_data
256 .iter()
257 .map(|(component_id, data)| ComponentDataRef { component_id, data })
258 }
259
260 pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
261 self.component_data.extract_component::<C>()
262 }
263
264 pub fn insert_or_update_component<C: Component>(
266 &mut self,
267 component: &C,
268 ) -> crate::MlsSpecResult<bool> {
269 self.component_data.insert_or_update_component(component)
270 }
271
272 pub fn apply_update(&mut self, update: AppDataUpdate) -> bool {
277 match update.op {
278 ApplicationDataUpdateOperation::Update { update: data } => {
279 *self.component_data.entry(update.component_id).or_default() = data;
280 true
281 }
282 ApplicationDataUpdateOperation::Remove => {
283 self.component_data.remove(&update.component_id).is_some()
284 }
285 }
286 }
287}
288
289impl From<ApplicationDataDictionary> for crate::group::extensions::Extension {
290 fn from(val: ApplicationDataDictionary) -> Self {
291 crate::group::extensions::Extension::ApplicationData(val)
292 }
293}
294
295#[derive(
296 Debug,
297 Clone,
298 PartialEq,
299 Eq,
300 tls_codec::TlsSerialize,
301 tls_codec::TlsDeserialize,
302 tls_codec::TlsSize,
303)]
304#[repr(u8)]
305#[cfg_attr(
306 feature = "serde",
307 derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
308)]
309pub enum ApplicationDataUpdateOperationType {
310 Invalid = 0x00,
311 Update = 0x01,
312 Remove = 0x02,
313}
314
315#[derive(
316 Debug,
317 Clone,
318 PartialEq,
319 Eq,
320 tls_codec::TlsSerialize,
321 tls_codec::TlsDeserialize,
322 tls_codec::TlsSize,
323)]
324#[repr(u8)]
325#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
326pub enum ApplicationDataUpdateOperation {
327 #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Update")]
328 Update {
329 #[tls_codec(with = "crate::tlspl::bytes")]
330 update: Vec<u8>,
331 },
332 #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Remove")]
333 Remove,
334}
335
336#[derive(
337 Debug,
338 Clone,
339 PartialEq,
340 Eq,
341 tls_codec::TlsSerialize,
342 tls_codec::TlsDeserialize,
343 tls_codec::TlsSize,
344)]
345#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
346pub struct AppDataUpdate {
347 pub component_id: ComponentId,
348 pub op: ApplicationDataUpdateOperation,
349}
350
351impl AppDataUpdate {
352 pub fn extract_component_update<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
357 let type_component_id = C::component_id();
358 if type_component_id != self.component_id {
359 return Err(crate::MlsSpecError::SafeAppComponentIdMismatch {
360 expected: type_component_id,
361 actual: self.component_id,
362 });
363 }
364
365 let ApplicationDataUpdateOperation::Update { update } = &self.op else {
366 return Ok(None);
367 };
368
369 Ok(Some(C::from_tls_bytes(update)?))
370 }
371}
372
373#[derive(
374 Debug,
375 Clone,
376 PartialEq,
377 Eq,
378 tls_codec::TlsSerialize,
379 tls_codec::TlsDeserialize,
380 tls_codec::TlsSize,
381)]
382#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
383pub struct ApplicationData {
384 pub component_id: ComponentId,
385 #[tls_codec(with = "crate::tlspl::bytes")]
386 pub data: Vec<u8>,
387}
388
389pub type AppEphemeral = ApplicationData;
390
391#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
392#[cfg_attr(feature = "serde", derive(serde::Serialize))]
393#[cfg_attr(feature = "serde", serde(transparent))]
394pub struct SafeAadItemRef<'a>(ComponentDataRef<'a>);
395
396impl<'a> SafeAadItemRef<'a> {
397 pub fn component_id(&self) -> &ComponentId {
398 self.0.component_id
399 }
400
401 pub fn aad_item_data(&self) -> &[u8] {
402 self.0.data
403 }
404
405 pub fn from_item_data<C: Component>(
406 component_id: &'a ComponentId,
407 aad_item_data: &'a [u8],
408 ) -> Option<Self> {
409 (&C::component_id() == component_id).then_some(SafeAadItemRef(ComponentDataRef {
410 component_id,
411 data: aad_item_data,
412 }))
413 }
414}
415
416#[derive(
417 Debug,
418 Clone,
419 PartialEq,
420 Eq,
421 Hash,
422 tls_codec::TlsSerialize,
423 tls_codec::TlsDeserialize,
424 tls_codec::TlsSize,
425)]
426#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
427#[cfg_attr(feature = "serde", serde(transparent))]
428pub struct SafeAadItem(ComponentData);
429
430impl SafeAadItem {
431 pub fn as_ref(&self) -> SafeAadItemRef<'_> {
432 SafeAadItemRef(self.0.as_ref())
433 }
434}
435
436#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
437#[cfg_attr(feature = "serde", derive(serde::Serialize))]
438pub struct SafeAadRef<'a> {
439 pub aad_items: &'a [&'a SafeAadItemRef<'a>],
440}
441
442impl SafeAadRef<'_> {
443 pub fn is_ordered_and_unique(&self) -> bool {
444 let mut iter = self.aad_items.iter().peekable();
445
446 while let Some(item) = iter.next() {
447 let Some(next) = iter.peek() else {
448 continue;
449 };
450
451 if item.component_id() >= next.component_id() {
452 return false;
453 }
454 }
455
456 true
457 }
458}
459
460impl<'a> From<&'a [&'a SafeAadItemRef<'a>]> for SafeAadRef<'a> {
461 fn from(aad_items: &'a [&'a SafeAadItemRef<'a>]) -> Self {
462 Self { aad_items }
463 }
464}
465
466#[derive(
467 Debug,
468 Default,
469 Clone,
470 PartialEq,
471 Eq,
472 Hash,
473 tls_codec::TlsSerialize,
474 tls_codec::TlsDeserialize,
475 tls_codec::TlsSize,
476)]
477#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
478pub struct SafeAad {
479 aad_items: ComponentDataMap,
480}
481
482impl SafeAad {
483 pub fn iter_components(&self) -> impl Iterator<Item = SafeAadItemRef<'_>> {
484 self.aad_items
485 .iter()
486 .map(|(component_id, data)| SafeAadItemRef(ComponentDataRef { component_id, data }))
487 }
488
489 pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
490 self.aad_items.extract_component::<C>()
491 }
492
493 pub fn insert_or_update_component<C: Component>(
495 &mut self,
496 component: &C,
497 ) -> crate::MlsSpecResult<bool> {
498 self.aad_items.insert_or_update_component(component)
499 }
500}
501
502#[derive(
503 Debug,
504 Clone,
505 PartialEq,
506 Eq,
507 Hash,
508 tls_codec::TlsSerialize,
509 tls_codec::TlsDeserialize,
510 tls_codec::TlsSize,
511)]
512#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
513pub struct WireFormats {
514 pub wire_formats: Vec<crate::defs::WireFormat>,
515}
516
517#[derive(
518 Debug,
519 Clone,
520 PartialEq,
521 Eq,
522 tls_codec::TlsSerialize,
523 tls_codec::TlsDeserialize,
524 tls_codec::TlsSize,
525)]
526#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
527pub struct ComponentsList {
528 pub component_ids: Vec<ComponentId>,
529}
530
531#[derive(
532 Debug,
533 Clone,
534 PartialEq,
535 Eq,
536 tls_codec::TlsSerialize,
537 tls_codec::TlsDeserialize,
538 tls_codec::TlsSize,
539)]
540pub struct AppComponents(pub ComponentsList);
541
542impl Component for AppComponents {
543 fn component_id() -> ComponentId {
544 super::APP_COMPONENTS_ID
545 }
546}
547
548#[derive(
549 Debug,
550 Clone,
551 PartialEq,
552 Eq,
553 tls_codec::TlsSerialize,
554 tls_codec::TlsDeserialize,
555 tls_codec::TlsSize,
556)]
557pub struct SafeAadComponent(pub ComponentsList);
558
559impl Component for SafeAadComponent {
560 fn component_id() -> ComponentId {
561 super::SAFE_AAD_ID
562 }
563}
564
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{ApplicationDataDictionary, Component, SafeAad, SafeAadItemRef, SafeAadRef};
    use crate::{
        drafts::mls_extensions::last_resort_keypackage::LastResortKeyPackage,
        generate_roundtrip_test,
    };

    // The literals are listed out of id order; the BTreeMap stores them sorted.
    generate_roundtrip_test!(can_roundtrip_appdatadict, {
        ApplicationDataDictionary {
            component_data: super::ComponentDataMap(BTreeMap::from([
                (2, vec![2]),
                (1, vec![1]),
                (3, vec![3]),
            ])),
        }
    });

    generate_roundtrip_test!(can_roundtrip_safeaad, {
        SafeAad {
            aad_items: super::ComponentDataMap(BTreeMap::from([
                (3, vec![3]),
                (2, vec![2]),
                (1, vec![1]),
            ])),
        }
    });

    #[test]
    fn can_build_safe_aad() {
        let mut safe_aad = SafeAad::default();
        safe_aad
            .insert_or_update_component(&LastResortKeyPackage)
            .unwrap();

        let component_id = LastResortKeyPackage::component_id();
        let item =
            SafeAadItemRef::from_item_data::<LastResortKeyPackage>(&component_id, &[]).unwrap();

        let items = [&item];
        let safe_ref: SafeAadRef<'_> = items.as_slice().into();
        assert!(safe_ref.is_ordered_and_unique());
    }
}