1use std::collections::BTreeMap;
2
3use crate::{
4 SensitiveBytes, ToPrefixedLabel,
5 defs::{
6 ProtocolVersion,
7 labels::{PublicKeyEncryptionLabel, SignatureLabel},
8 },
9 key_schedule::PreSharedKeyId,
10};
11
/// Numeric identifier of an application component (an unsigned 32-bit integer).
pub type ComponentId = u32;
13
14pub trait Component: crate::Parsable + crate::Serializable {
15 fn component_id() -> ComponentId;
16
17 fn psk(psk_id: Vec<u8>, psk_nonce: SensitiveBytes) -> PreSharedKeyId {
18 PreSharedKeyId {
19 psktype: crate::key_schedule::PreSharedKeyIdPskType::Application(
20 crate::key_schedule::ApplicationPsk {
21 component_id: Self::component_id(),
22 psk_id,
23 },
24 ),
25 psk_nonce,
26 }
27 }
28
29 fn to_component_data(&self) -> crate::MlsSpecResult<ComponentData> {
30 Ok(ComponentData {
31 component_id: Self::component_id(),
32 data: self.to_tls_bytes()?,
33 })
34 }
35}
36
37#[derive(
38 Debug,
39 Clone,
40 PartialEq,
41 Eq,
42 tls_codec::TlsSerialize,
43 tls_codec::TlsDeserialize,
44 tls_codec::TlsSize,
45)]
46#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
47pub struct ComponentOperationLabel {
48 #[tls_codec(with = "crate::tlspl::bytes")]
49 label: Vec<u8>,
50 pub component_id: ComponentId,
51 #[tls_codec(with = "crate::tlspl::bytes")]
52 pub context: Vec<u8>,
53}
54
55impl ComponentOperationLabel {
56 pub fn new_with_context_for_hpke(component_id: ComponentId, context: Vec<u8>) -> Self {
58 Self {
59 label: PublicKeyEncryptionLabel::SafeApp
60 .to_prefixed_string(ProtocolVersion::default())
61 .into_bytes(),
62 component_id,
63 context,
64 }
65 }
66
67 pub fn new_for_signature(
69 label: SignatureLabel,
70 component_id: ComponentId,
71 context: Vec<u8>,
72 ) -> (Self, SignatureLabel) {
73 let col = Self {
74 label: label
75 .to_prefixed_string(ProtocolVersion::default())
76 .into_bytes(),
77 component_id,
78 context,
79 };
80
81 (col, SignatureLabel::ComponentOperationLabel)
82 }
83}
84
85#[derive(
86 Debug,
87 Clone,
88 PartialEq,
89 Eq,
90 tls_codec::TlsSerialize,
91 tls_codec::TlsDeserialize,
92 tls_codec::TlsSize,
93)]
94#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
95pub struct ComponentData {
96 pub component_id: ComponentId,
97 #[tls_codec(with = "crate::tlspl::bytes")]
98 pub data: Vec<u8>,
99}
100
101impl ComponentData {
102 pub fn as_ref(&self) -> ComponentDataRef {
103 ComponentDataRef {
104 component_id: &self.component_id,
105 data: &self.data,
106 }
107 }
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
111#[cfg_attr(feature = "serde", derive(serde::Serialize))]
112pub struct ComponentDataRef<'a> {
113 pub component_id: &'a ComponentId,
114 #[tls_codec(with = "crate::tlspl::bytes")]
115 pub data: &'a [u8],
116}
117
118#[derive(Debug, Default, Clone, PartialEq, Eq)]
123#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
124#[cfg_attr(
125 feature = "serde",
126 serde(from = "Vec<ComponentData>", into = "Vec<ComponentData>")
127)]
128pub struct ComponentDataMap(BTreeMap<ComponentId, Vec<u8>>);
129
130impl ComponentDataMap {
131 fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
132 self.0
133 .get(&C::component_id())
134 .map(|data| C::from_tls_bytes(data))
135 .transpose()
136 }
137
138 fn insert_or_update_component<C: Component>(
139 &mut self,
140 component: &C,
141 ) -> crate::MlsSpecResult<bool> {
142 let component_data = component.to_tls_bytes()?;
144 match self.0.entry(C::component_id()) {
145 std::collections::btree_map::Entry::Vacant(vacant_entry) => {
146 vacant_entry.insert(component_data);
147 Ok(true)
148 }
149 std::collections::btree_map::Entry::Occupied(mut occupied_entry) => {
150 *(occupied_entry.get_mut()) = component_data;
151 Ok(false)
152 }
153 }
154 }
155
156 fn iter(&self) -> impl Iterator<Item = (&ComponentId, &[u8])> {
157 self.0.iter().map(|(cid, data)| (cid, data.as_slice()))
158 }
159}
160
161impl tls_codec::Size for ComponentDataMap {
162 fn tls_serialized_len(&self) -> usize {
163 crate::tlspl::tls_serialized_len_as_vlvec(
164 self.iter()
165 .map(|(component_id, data)| {
166 ComponentDataRef { component_id, data }.tls_serialized_len()
167 })
168 .sum(),
169 )
170 }
171}
172
173impl tls_codec::Deserialize for ComponentDataMap {
174 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
175 where
176 Self: Sized,
177 {
178 let tlspl_value: Vec<ComponentData> = <_>::tls_deserialize(bytes)?;
179
180 Ok(Self(BTreeMap::from_iter(
181 tlspl_value
182 .into_iter()
183 .map(|cdata| (cdata.component_id, cdata.data)),
184 )))
185 }
186}
187
188impl tls_codec::Serialize for ComponentDataMap {
189 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
190 self.iter()
192 .map(|(component_id, data)| ComponentDataRef { component_id, data })
193 .collect::<Vec<ComponentDataRef>>()
194 .tls_serialize(writer)
195 }
196}
197
198impl std::ops::Deref for ComponentDataMap {
199 type Target = BTreeMap<ComponentId, Vec<u8>>;
200 fn deref(&self) -> &Self::Target {
201 &self.0
202 }
203}
204
205impl std::ops::DerefMut for ComponentDataMap {
206 fn deref_mut(&mut self) -> &mut Self::Target {
207 &mut self.0
208 }
209}
210
211impl From<Vec<ComponentData>> for ComponentDataMap {
212 fn from(value: Vec<ComponentData>) -> Self {
213 Self(BTreeMap::from_iter(
214 value
215 .into_iter()
216 .map(|component| (component.component_id, component.data)),
217 ))
218 }
219}
220
221#[allow(clippy::from_over_into)]
222impl Into<Vec<ComponentData>> for ComponentDataMap {
223 fn into(self) -> Vec<ComponentData> {
224 self.0
225 .into_iter()
226 .map(|(component_id, data)| ComponentData { component_id, data })
227 .collect()
228 }
229}
230
231#[derive(
236 Debug,
237 Default,
238 Clone,
239 PartialEq,
240 Eq,
241 tls_codec::TlsSize,
242 tls_codec::TlsDeserialize,
243 tls_codec::TlsSerialize,
244)]
245#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
246pub struct ApplicationDataDictionary {
247 pub component_data: ComponentDataMap,
248}
249
250impl ApplicationDataDictionary {
251 pub fn iter_components(&self) -> impl Iterator<Item = ComponentDataRef> {
252 self.component_data
253 .iter()
254 .map(|(component_id, data)| ComponentDataRef { component_id, data })
255 }
256
257 pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
258 self.component_data.extract_component::<C>()
259 }
260
261 pub fn insert_or_update_component<C: Component>(
263 &mut self,
264 component: &C,
265 ) -> crate::MlsSpecResult<bool> {
266 self.component_data.insert_or_update_component(component)
267 }
268
269 pub fn apply_update(&mut self, update: AppDataUpdate) -> bool {
274 match update.op {
275 ApplicationDataUpdateOperation::Update { update: data } => {
276 *self.component_data.entry(update.component_id).or_default() = data;
277 true
278 }
279 ApplicationDataUpdateOperation::Remove => {
280 self.component_data.remove(&update.component_id).is_some()
281 }
282 }
283 }
284}
285
286impl From<ApplicationDataDictionary> for crate::group::extensions::Extension {
287 fn from(val: ApplicationDataDictionary) -> Self {
288 crate::group::extensions::Extension::ApplicationData(val)
289 }
290}
291
292#[derive(
293 Debug,
294 Clone,
295 PartialEq,
296 Eq,
297 tls_codec::TlsSerialize,
298 tls_codec::TlsDeserialize,
299 tls_codec::TlsSize,
300)]
301#[repr(u8)]
302#[cfg_attr(
303 feature = "serde",
304 derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
305)]
306pub enum ApplicationDataUpdateOperationType {
307 Invalid = 0x00,
308 Update = 0x01,
309 Remove = 0x02,
310}
311
312#[derive(
313 Debug,
314 Clone,
315 PartialEq,
316 Eq,
317 tls_codec::TlsSerialize,
318 tls_codec::TlsDeserialize,
319 tls_codec::TlsSize,
320)]
321#[repr(u8)]
322#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
323pub enum ApplicationDataUpdateOperation {
324 #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Update")]
325 Update {
326 #[tls_codec(with = "crate::tlspl::bytes")]
327 update: Vec<u8>,
328 },
329 #[tls_codec(discriminant = "ApplicationDataUpdateOperationType::Remove")]
330 Remove,
331}
332
333#[derive(
334 Debug,
335 Clone,
336 PartialEq,
337 Eq,
338 tls_codec::TlsSerialize,
339 tls_codec::TlsDeserialize,
340 tls_codec::TlsSize,
341)]
342#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
343pub struct AppDataUpdate {
344 pub component_id: ComponentId,
345 pub op: ApplicationDataUpdateOperation,
346}
347
348impl AppDataUpdate {
349 pub fn extract_component_update<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
354 let type_component_id = C::component_id();
355 if type_component_id != self.component_id {
356 return Err(crate::MlsSpecError::SafeAppComponentIdMismatch {
357 expected: type_component_id,
358 actual: self.component_id,
359 });
360 }
361
362 let ApplicationDataUpdateOperation::Update { update } = &self.op else {
363 return Ok(None);
364 };
365
366 Ok(Some(C::from_tls_bytes(update)?))
367 }
368}
369
370#[derive(
371 Debug,
372 Clone,
373 PartialEq,
374 Eq,
375 tls_codec::TlsSerialize,
376 tls_codec::TlsDeserialize,
377 tls_codec::TlsSize,
378)]
379#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
380pub struct ApplicationData {
381 pub component_id: ComponentId,
382 #[tls_codec(with = "crate::tlspl::bytes")]
383 pub data: Vec<u8>,
384}
385
386pub type AppEphemeral = ApplicationData;
387
388#[derive(Debug, Clone, PartialEq, Eq, Hash, tls_codec::TlsSerialize, tls_codec::TlsSize)]
389#[cfg_attr(feature = "serde", derive(serde::Serialize))]
390pub struct SafeAadItemRefOld<'a> {
391 pub component_id: &'a ComponentId,
392 #[tls_codec(with = "crate::tlspl::bytes")]
393 pub aad_item_data: &'a [u8],
394}
395
396#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
397#[cfg_attr(feature = "serde", derive(serde::Serialize))]
398#[cfg_attr(feature = "serde", serde(transparent))]
399pub struct SafeAadItemRef<'a>(ComponentDataRef<'a>);
400
401impl SafeAadItemRef<'_> {
402 pub fn component_id(&self) -> &ComponentId {
403 self.0.component_id
404 }
405
406 pub fn aad_item_data(&self) -> &[u8] {
407 self.0.data
408 }
409}
410
411#[derive(
412 Debug,
413 Clone,
414 PartialEq,
415 Eq,
416 tls_codec::TlsSerialize,
417 tls_codec::TlsDeserialize,
418 tls_codec::TlsSize,
419)]
420#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
421#[cfg_attr(feature = "serde", serde(transparent))]
422pub struct SafeAadItem(ComponentData);
423
424impl SafeAadItem {
425 pub fn as_ref(&self) -> SafeAadItemRef {
426 SafeAadItemRef(self.0.as_ref())
427 }
428}
429
430#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
431#[cfg_attr(feature = "serde", derive(serde::Serialize))]
432pub struct SafeAadRef<'a> {
433 pub aad_items: &'a [&'a SafeAadItemRef<'a>],
434}
435
436#[derive(
437 Debug,
438 Clone,
439 PartialEq,
440 Eq,
441 tls_codec::TlsSerialize,
442 tls_codec::TlsDeserialize,
443 tls_codec::TlsSize,
444)]
445#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
446pub struct SafeAad {
447 aad_items: ComponentDataMap,
448}
449
450impl SafeAad {
451 pub fn iter_components(&self) -> impl Iterator<Item = SafeAadItemRef> {
452 self.aad_items
453 .iter()
454 .map(|(component_id, data)| SafeAadItemRef(ComponentDataRef { component_id, data }))
455 }
456
457 pub fn extract_component<C: Component>(&self) -> crate::MlsSpecResult<Option<C>> {
458 self.aad_items.extract_component::<C>()
459 }
460
461 pub fn insert_or_update_component<C: Component>(
463 &mut self,
464 component: &C,
465 ) -> crate::MlsSpecResult<bool> {
466 self.aad_items.insert_or_update_component(component)
467 }
468}
469
470#[derive(
471 Debug,
472 Clone,
473 PartialEq,
474 Eq,
475 tls_codec::TlsSerialize,
476 tls_codec::TlsDeserialize,
477 tls_codec::TlsSize,
478)]
479#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
480pub struct WireFormats {
481 pub wire_formats: Vec<crate::defs::WireFormat>,
482}
483
484#[derive(
485 Debug,
486 Clone,
487 PartialEq,
488 Eq,
489 tls_codec::TlsSerialize,
490 tls_codec::TlsDeserialize,
491 tls_codec::TlsSize,
492)]
493#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
494pub struct ComponentsList {
495 pub component_ids: Vec<ComponentId>,
496}
497
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{ApplicationDataDictionary, ComponentDataMap, SafeAad};
    use crate::generate_roundtrip_test;

    /// Three single-byte entries. The literal lists them out of id order; the
    /// `BTreeMap` stores them sorted.
    fn sample_component_map() -> ComponentDataMap {
        ComponentDataMap(BTreeMap::from([
            (1, vec![1]),
            (3, vec![3]),
            (2, vec![2]),
        ]))
    }

    generate_roundtrip_test!(can_roundtrip_appdatadict, {
        ApplicationDataDictionary {
            component_data: sample_component_map(),
        }
    });

    generate_roundtrip_test!(can_roundtrip_safeaad, {
        SafeAad {
            aad_items: sample_component_map(),
        }
    });
}
524}