1use alloc::vec::Vec;
6use core::{
7 fmt::{self, Debug},
8 ops::{Deref, DerefMut},
9};
10
11use zeroize::Zeroizing;
12
13use crate::{client::MlsError, map::LargeMap, tree_kem::math::TreeIndex, CipherSuiteProvider};
14
15use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
16use mls_rs_core::error::IntoAnyError;
17
18use super::key_schedule::kdf_expand_with_label;
19
/// Maximum number of generations `SecretKeyRatchet::get_message_key` may advance a
/// ratchet beyond its current generation in one request. Requests for a generation
/// further ahead fail with `MlsError::InvalidFutureGeneration`.
pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024;
21
/// A node stored in the secret tree.
///
/// A node holds either a tree secret that has not been expanded yet, or, for a leaf
/// that has already been used, that leaf's pair of key ratchets.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// The explicit discriminants presumably serve as the MLS codec tag; keep them
// stable (TODO confirm against the mls_rs_codec derive).
#[repr(u8)]
enum SecretTreeNode {
    /// An unexpanded secret for this node.
    Secret(TreeSecret) = 0u8,
    /// Application/handshake ratchets derived from a leaf secret.
    Ratchet(SecretRatchets) = 1u8,
}
29
30impl SecretTreeNode {
31 fn into_secret(self) -> Option<TreeSecret> {
32 if let SecretTreeNode::Secret(secret) = self {
33 Some(secret)
34 } else {
35 None
36 }
37 }
38}
39
/// Secret bytes for one secret-tree node (or a ratchet's current secret).
///
/// The bytes are held in `Zeroizing`, so they are wiped when dropped. The bytes are
/// encoded as a length-prefixed byte vector.
#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct TreeSecret(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    Zeroizing<Vec<u8>>,
);
47
48impl Debug for TreeSecret {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.debug_struct("TreeSecret").finish()
51 }
52}
53
54impl Deref for TreeSecret {
55 type Target = Vec<u8>;
56
57 fn deref(&self) -> &Self::Target {
58 &self.0
59 }
60}
61
62impl DerefMut for TreeSecret {
63 fn deref_mut(&mut self) -> &mut Self::Target {
64 &mut self.0
65 }
66}
67
68impl AsRef<[u8]> for TreeSecret {
69 fn as_ref(&self) -> &[u8] {
70 &self.0
71 }
72}
73
74impl From<Vec<u8>> for TreeSecret {
75 fn from(vec: Vec<u8>) -> Self {
76 TreeSecret(Zeroizing::new(vec))
77 }
78}
79
80impl From<Zeroizing<Vec<u8>>> for TreeSecret {
81 fn from(vec: Zeroizing<Vec<u8>>) -> Self {
82 TreeSecret(vec)
83 }
84}
85
/// Sparse storage for secret-tree nodes, keyed by tree index.
///
/// Only nodes that have been derived and not yet consumed are present.
#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct TreeSecretsVec<T: TreeIndex> {
    inner: LargeMap<T, SecretTreeNode>,
}
91
92impl<T: TreeIndex> TreeSecretsVec<T> {
93 fn set_node(&mut self, index: T, value: SecretTreeNode) {
94 self.inner.insert(index, value);
95 }
96
97 fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
98 self.inner.remove(index)
99 }
100}
101
/// Secret tree that is expanded on demand.
///
/// `SecretTree::new` stores only the root secret. Interior secrets are derived, and
/// their parents removed, as leaves are requested. Each used leaf is then stored
/// as a ratchet pair.
#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretTree<T: TreeIndex> {
    /// Nodes that currently hold a secret or a ratchet pair.
    known_secrets: TreeSecretsVec<T>,
    /// Number of leaves in the tree; used to locate the root and copaths.
    leaf_count: T,
}
108
109impl<T: TreeIndex> SecretTree<T> {
110 pub(crate) fn empty() -> SecretTree<T> {
111 SecretTree {
112 known_secrets: Default::default(),
113 leaf_count: T::zero(),
114 }
115 }
116}
117
/// The application and handshake key ratchets belonging to one leaf.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretRatchets {
    /// Ratchet used for `KeyType::Application`.
    pub application: SecretKeyRatchet,
    /// Ratchet used for `KeyType::Handshake`.
    pub handshake: SecretKeyRatchet,
}
124
125impl SecretRatchets {
126 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
127 pub async fn message_key_generation<P: CipherSuiteProvider>(
128 &mut self,
129 cipher_suite_provider: &P,
130 generation: u32,
131 key_type: KeyType,
132 ) -> Result<MessageKeyData, MlsError> {
133 match key_type {
134 KeyType::Handshake => {
135 self.handshake
136 .get_message_key(cipher_suite_provider, generation)
137 .await
138 }
139 KeyType::Application => {
140 self.application
141 .get_message_key(cipher_suite_provider, generation)
142 .await
143 }
144 }
145 }
146
147 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
148 pub async fn next_message_key<P: CipherSuiteProvider>(
149 &mut self,
150 cipher_suite: &P,
151 key_type: KeyType,
152 ) -> Result<MessageKeyData, MlsError> {
153 match key_type {
154 KeyType::Handshake => self.handshake.next_message_key(cipher_suite).await,
155 KeyType::Application => self.application.next_message_key(cipher_suite).await,
156 }
157 }
158
159 #[cfg(feature = "export_key_generation")]
162 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync, allow(dead_code))]
163 async fn peek_next_key_generation(&self, key_type: KeyType) -> u32 {
164 match key_type {
165 KeyType::Handshake => self.handshake.peek_next_key_generation().await,
166 KeyType::Application => self.application.peek_next_key_generation().await,
167 }
168 }
169}
170
171impl<T: TreeIndex> SecretTree<T> {
172 pub fn new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T> {
173 let mut known_secrets = TreeSecretsVec::default();
174
175 let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret));
176 known_secrets.set_node(leaf_count.root(), root_secret);
177
178 Self {
179 known_secrets,
180 leaf_count,
181 }
182 }
183
184 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
185 async fn consume_node<P: CipherSuiteProvider>(
186 &mut self,
187 cipher_suite_provider: &P,
188 index: &T,
189 ) -> Result<(), MlsError> {
190 let node = self.known_secrets.take_node(index);
191
192 if let Some(secret) = node.and_then(|n| n.into_secret()) {
193 let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?;
194 let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?;
195
196 let left_secret =
197 kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None)
198 .await?;
199
200 let right_secret =
201 kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None)
202 .await?;
203
204 self.known_secrets
205 .set_node(left_index, SecretTreeNode::Secret(left_secret.into()));
206
207 self.known_secrets
208 .set_node(right_index, SecretTreeNode::Secret(right_secret.into()));
209 }
210
211 Ok(())
212 }
213
214 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
215 async fn take_leaf_ratchet<P: CipherSuiteProvider>(
216 &mut self,
217 cipher_suite: &P,
218 leaf_index: &T,
219 ) -> Result<SecretRatchets, MlsError> {
220 let node_index = leaf_index;
221
222 let node = match self.known_secrets.take_node(node_index) {
223 Some(node) => node,
224 None => {
225 for i in node_index.direct_copath(&self.leaf_count).into_iter().rev() {
227 self.consume_node(cipher_suite, &i.path).await?;
228 }
229
230 self.known_secrets
231 .take_node(node_index)
232 .ok_or(MlsError::InvalidLeafConsumption)?
233 }
234 };
235
236 Ok(match node {
237 SecretTreeNode::Ratchet(ratchet) => ratchet,
238 SecretTreeNode::Secret(secret) => SecretRatchets {
239 application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application)
240 .await?,
241 handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?,
242 },
243 })
244 }
245
246 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
247 pub async fn next_message_key<P: CipherSuiteProvider>(
248 &mut self,
249 cipher_suite: &P,
250 leaf_index: T,
251 key_type: KeyType,
252 ) -> Result<MessageKeyData, MlsError> {
253 let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
254 let res = ratchet.next_message_key(cipher_suite, key_type).await?;
255
256 self.known_secrets
257 .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
258
259 Ok(res)
260 }
261
262 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
263 pub async fn message_key_generation<P: CipherSuiteProvider>(
264 &mut self,
265 cipher_suite: &P,
266 leaf_index: T,
267 key_type: KeyType,
268 generation: u32,
269 ) -> Result<MessageKeyData, MlsError> {
270 let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
271
272 let res = ratchet
273 .message_key_generation(cipher_suite, generation, key_type)
274 .await;
275
276 self.known_secrets
277 .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
278
279 res
280 }
281
282 #[cfg(feature = "export_key_generation")]
292 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync, allow(dead_code))]
293 pub async fn peek_next_key_generation<P: CipherSuiteProvider>(
294 &mut self,
295 cipher_suite: &P,
296 leaf_index: T,
297 key_type: KeyType,
298 ) -> Result<u32, MlsError> {
299 let ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
300 let res = ratchet.peek_next_key_generation(key_type).await;
301
302 self.known_secrets
303 .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
304
305 Ok(res)
306 }
307}
308
/// Selects which of a leaf's two key ratchets to use.
#[derive(Clone, Copy)]
pub enum KeyType {
    /// The ratchet whose initial secret is derived with the label `"handshake"`.
    Handshake,
    /// The ratchet whose initial secret is derived with the label `"application"`.
    Application,
}
314
/// AEAD key and nonce derived for one ratchet generation.
///
/// Field order here is the MLS encoding order.
#[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MessageKeyData {
    /// Nonce of `aead_nonce_size()` bytes.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    pub(crate) nonce: Zeroizing<Vec<u8>>,
    /// Key of `aead_key_size()` bytes.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    pub(crate) key: Zeroizing<Vec<u8>>,
    /// Ratchet generation this key and nonce belong to.
    pub(crate) generation: u32,
}
327
328impl Debug for MessageKeyData {
329 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
330 f.debug_struct("MessageKeyData")
331 .field("generation", &self.generation)
332 .finish()
333 }
334}
335
336impl MessageKeyData {
337 #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
339 pub fn nonce(&self) -> &[u8] {
340 &self.nonce
341 }
342
343 #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
345 pub fn key(&self) -> &[u8] {
346 &self.key
347 }
348
349 #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
351 pub fn generation(&self) -> u32 {
352 self.generation
353 }
354}
355
/// One symmetric key ratchet (application or handshake) for a single leaf.
///
/// `secret` is the ratchet secret for `generation`. Each `next_message_key` call
/// derives that generation's key and nonce, then moves `secret` forward by one
/// generation. With the `out_of_order` feature, keys skipped over while advancing
/// are kept in `history`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretKeyRatchet {
    /// Current ratchet secret.
    secret: TreeSecret,
    /// Generation whose key the next `next_message_key` call will produce.
    generation: u32,
    /// Keys for skipped generations, keyed by generation; each is removed when used.
    #[cfg(feature = "out_of_order")]
    history: LargeMap<u32, MessageKeyData>,
}
364
365impl MlsSize for SecretKeyRatchet {
366 fn mls_encoded_len(&self) -> usize {
367 let len = mls_rs_codec::byte_vec::mls_encoded_len(&self.secret)
368 + self.generation.mls_encoded_len();
369
370 #[cfg(feature = "out_of_order")]
371 return len + mls_rs_codec::iter::mls_encoded_len(self.history.values());
372 #[cfg(not(feature = "out_of_order"))]
373 return len;
374 }
375}
376
377#[cfg(feature = "out_of_order")]
378impl MlsEncode for SecretKeyRatchet {
379 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
380 mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
381 self.generation.mls_encode(writer)?;
382 mls_rs_codec::iter::mls_encode(self.history.values(), writer)
383 }
384}
385
386#[cfg(not(feature = "out_of_order"))]
387impl MlsEncode for SecretKeyRatchet {
388 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
389 mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
390 self.generation.mls_encode(writer)
391 }
392}
393
394impl MlsDecode for SecretKeyRatchet {
395 fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
396 Ok(Self {
397 secret: mls_rs_codec::byte_vec::mls_decode(reader)?,
398 generation: u32::mls_decode(reader)?,
399 #[cfg(feature = "out_of_order")]
400 history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
401 let mut items = LargeMap::default();
402
403 while !data.is_empty() {
404 let item = MessageKeyData::mls_decode(data)?;
405 items.insert(item.generation, item);
406 }
407
408 Ok(items)
409 })?,
410 })
411 }
412}
413
414impl SecretKeyRatchet {
415 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
416 async fn new<P: CipherSuiteProvider>(
417 cipher_suite_provider: &P,
418 secret: &[u8],
419 key_type: KeyType,
420 ) -> Result<Self, MlsError> {
421 let label = match key_type {
422 KeyType::Handshake => b"handshake".as_slice(),
423 KeyType::Application => b"application".as_slice(),
424 };
425
426 let secret = kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None)
427 .await
428 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
429
430 Ok(Self {
431 secret: TreeSecret::from(secret),
432 generation: 0,
433 #[cfg(feature = "out_of_order")]
434 history: Default::default(),
435 })
436 }
437
438 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
439 async fn get_message_key<P: CipherSuiteProvider>(
440 &mut self,
441 cipher_suite_provider: &P,
442 generation: u32,
443 ) -> Result<MessageKeyData, MlsError> {
444 #[cfg(feature = "out_of_order")]
445 if generation < self.generation {
446 return self
447 .history
448 .remove_entry(&generation)
449 .map(|(_, mk)| mk)
450 .ok_or(MlsError::KeyMissing(generation));
451 }
452
453 #[cfg(not(feature = "out_of_order"))]
454 if generation < self.generation {
455 return Err(MlsError::KeyMissing(generation));
456 }
457
458 let max_generation_allowed = self.generation + MAX_RATCHET_BACK_HISTORY;
459
460 if generation > max_generation_allowed {
461 return Err(MlsError::InvalidFutureGeneration(generation));
462 }
463
464 #[cfg(not(feature = "out_of_order"))]
465 while self.generation < generation {
466 self.next_message_key(cipher_suite_provider)?;
467 }
468
469 #[cfg(feature = "out_of_order")]
470 while self.generation < generation {
471 let key_data = self.next_message_key(cipher_suite_provider).await?;
472 self.history.insert(key_data.generation, key_data);
473 }
474
475 self.next_message_key(cipher_suite_provider).await
476 }
477
478 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
479 async fn next_message_key<P: CipherSuiteProvider>(
480 &mut self,
481 cipher_suite_provider: &P,
482 ) -> Result<MessageKeyData, MlsError> {
483 let generation = self.generation;
484
485 let key = MessageKeyData {
486 nonce: self
487 .derive_secret(
488 cipher_suite_provider,
489 b"nonce",
490 cipher_suite_provider.aead_nonce_size(),
491 )
492 .await?,
493 key: self
494 .derive_secret(
495 cipher_suite_provider,
496 b"key",
497 cipher_suite_provider.aead_key_size(),
498 )
499 .await?,
500 generation,
501 };
502
503 self.secret = self
504 .derive_secret(
505 cipher_suite_provider,
506 b"secret",
507 cipher_suite_provider.kdf_extract_size(),
508 )
509 .await?
510 .into();
511
512 self.generation = generation + 1;
513
514 Ok(key)
515 }
516
517 #[cfg(feature = "export_key_generation")]
520 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync, allow(dead_code))]
521 async fn peek_next_key_generation(&self) -> u32 {
522 self.generation
523 }
524
525 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
526 async fn derive_secret<P: CipherSuiteProvider>(
527 &self,
528 cipher_suite_provider: &P,
529 label: &[u8],
530 len: usize,
531 ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
532 kdf_expand_with_label(
533 cipher_suite_provider,
534 self.secret.as_ref(),
535 label,
536 &self.generation.to_be_bytes(),
537 Some(len),
538 )
539 .await
540 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
541 }
542}
543
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::{string::String, vec::Vec};
    use mls_rs_core::crypto::CipherSuiteProvider;
    use zeroize::Zeroizing;

    use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex};

    use super::{KeyType, SecretKeyRatchet, SecretTree};

    /// Builds a secret tree for `leaf_count` leaves whose root holds `secret`.
    pub(crate) fn get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T> {
        let encryption_secret = Zeroizing::new(secret);
        SecretTree::new(leaf_count, encryption_secret)
    }

    impl SecretTree<u32> {
        /// Returns a copy of the secret currently stored at the root.
        ///
        /// Panics if the root has already been consumed.
        pub(crate) fn get_root_secret(&self) -> Vec<u8> {
            let root = self.leaf_count.root();
            let mut snapshot = self.known_secrets.clone();

            let root_node = snapshot.take_node(&root).unwrap();
            root_node.into_secret().unwrap().to_vec()
        }
    }

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct RatchetInteropTestCase {
        #[serde(with = "hex::serde")]
        secret: Vec<u8>,
        label: String,
        generation: u32,
        length: usize,
        #[serde(with = "hex::serde")]
        out: Vec<u8>,
    }

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        derive_tree_secret: RatchetInteropTestCase,
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        let test_cases: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());

        for test_case in test_cases {
            // Skip cipher suites the test crypto provider does not support.
            let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
                continue;
            };

            test_case.derive_tree_secret.verify(&cs).await;
        }
    }

    impl RatchetInteropTestCase {
        /// Checks `derive_secret` against the vector's expected output.
        ///
        /// A ratchet is built, then its state is overwritten with the vector's
        /// secret and generation before deriving.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
            let mut ratchet = SecretKeyRatchet::new(cs, &self.secret, KeyType::Application)
                .await
                .unwrap();

            ratchet.secret = self.secret.clone().into();
            ratchet.generation = self.generation;

            let derived = ratchet
                .derive_secret(cs, self.label.as_bytes(), self.length)
                .await
                .unwrap();

            assert_eq!(derived.to_vec(), self.out);
        }
    }
}
618
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::{
        cipher_suite::CipherSuite,
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::{
            test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
        },
        tree_kem::node::NodeIndex,
    };

    #[cfg(not(mls_build_async))]
    use crate::group::test_utils::random_bytes;

    use super::{test_utils::get_test_tree, *};

    use assert_matches::assert_matches;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_tree() {
        test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await;
        test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await;
    }

    /// Takes the ratchets for every leaf in `leaves_to_check` and checks that they
    /// are pairwise distinct.
    ///
    /// When `all_deleted` is set, it also checks that nothing else remains stored
    /// in the tree.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_secret_tree_custom<T: TreeIndex>(
        leaf_count: T,
        leaves_to_check: Vec<T>,
        all_deleted: bool,
    ) {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let cs_provider = test_cipher_suite_provider(cipher_suite);

            let test_secret = vec![0u8; cs_provider.kdf_extract_size()];
            let mut test_tree = get_test_tree(test_secret, leaf_count.clone());

            let mut secrets = Vec::<SecretRatchets>::new();

            for i in &leaves_to_check {
                let secret = test_tree.take_leaf_ratchet(&cs_provider, i).await.unwrap();
                secrets.push(secret);
            }

            assert!(!all_deleted || test_tree.known_secrets.inner.is_empty());

            // Compare every pair: `dedup` would only catch *adjacent* duplicates.
            for (i, first) in secrets.iter().enumerate() {
                for second in &secrets[i + 1..] {
                    assert_ne!(first, second);
                }
            }
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_key_ratchet() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let provider = test_cipher_suite_provider(cipher_suite);

            let mut app_ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Application,
            )
            .await
            .unwrap();

            let mut handshake_ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Handshake,
            )
            .await
            .unwrap();

            let app_key_one = app_ratchet.next_message_key(&provider).await.unwrap();
            let app_key_two = app_ratchet.next_message_key(&provider).await.unwrap();
            let app_keys = vec![app_key_one, app_key_two];

            let handshake_key_one = handshake_ratchet.next_message_key(&provider).await.unwrap();
            let handshake_key_two = handshake_ratchet.next_message_key(&provider).await.unwrap();
            let handshake_keys = vec![handshake_key_one, handshake_key_two];

            assert_ne!(app_keys, handshake_keys);

            assert_ne!(app_keys[0], app_keys[1]);
            assert_ne!(handshake_keys[0], handshake_keys[1]);
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_get_key() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let provider = test_cipher_suite_provider(cipher_suite);

            let mut ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Application,
            )
            .await
            .unwrap();

            let mut ratchet_clone = ratchet.clone();

            let _ = ratchet_clone.next_message_key(&provider).await.unwrap();
            let clone_2 = ratchet_clone.next_message_key(&provider).await.unwrap();

            // Generation 0 was produced by `next_message_key`, not stored in history.
            let res = ratchet_clone.get_message_key(&provider, 0).await;
            assert!(res.is_err());

            let second_key = ratchet
                .get_message_key(&provider, ratchet_clone.generation - 1)
                .await
                .unwrap();

            assert_eq!(clone_2, second_key)
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_ratchet() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let provider = test_cipher_suite_provider(cipher_suite);

            let mut ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Application,
            )
            .await
            .unwrap();

            let original_secret = ratchet.secret.clone();
            let _ = ratchet.next_message_key(&provider).await.unwrap();
            let new_secret = ratchet.secret;
            assert_ne!(original_secret, new_secret)
        }
    }

    #[cfg(feature = "out_of_order")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_out_of_order_keys() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let provider = test_cipher_suite_provider(cipher_suite);

        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
            .await
            .unwrap();
        let mut ratchet_clone = ratchet.clone();

        let mut ordered_keys = Vec::<MessageKeyData>::new();

        for i in 0..=MAX_RATCHET_BACK_HISTORY {
            ordered_keys.push(ratchet.get_message_key(&provider, i).await.unwrap());
        }

        // Jumping straight to the furthest allowed generation stores every
        // skipped key (0..MAX_RATCHET_BACK_HISTORY) in the history.
        let last_key = ratchet_clone
            .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY)
            .await
            .unwrap();

        assert_eq!(last_key, ordered_keys[ordered_keys.len() - 1]);

        let mut back_history_keys = Vec::<MessageKeyData>::new();

        // Every skipped generation, including the last one, must be retrievable.
        for i in 0..MAX_RATCHET_BACK_HISTORY {
            back_history_keys.push(ratchet_clone.get_message_key(&provider, i).await.unwrap());
        }

        assert_eq!(
            back_history_keys,
            ordered_keys[..MAX_RATCHET_BACK_HISTORY as usize]
        );
    }

    #[cfg(not(feature = "out_of_order"))]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn out_of_order_keys_should_throw_error() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let provider = test_cipher_suite_provider(cipher_suite);

        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
            .await
            .unwrap();

        ratchet.get_message_key(&provider, 10).await.unwrap();
        let res = ratchet.get_message_key(&provider, 9).await;
        assert_matches!(res, Err(MlsError::KeyMissing(9)))
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_too_out_of_order() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let provider = test_cipher_suite_provider(cipher_suite);

        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
            .await
            .unwrap();

        let res = ratchet
            .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY + 1)
            .await;

        let invalid_generation = MAX_RATCHET_BACK_HISTORY + 1;

        assert_matches!(
            res,
            Err(MlsError::InvalidFutureGeneration(invalid))
            if invalid == invalid_generation
        )
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn double_hit_leaves_epoch_intact() {
        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let test_secret = vec![0u8; cs.kdf_extract_size()];
        let mut test_tree = get_test_tree(test_secret, 4u32);
        let key_type = KeyType::Application;

        test_tree
            .message_key_generation(&cs, 4, key_type, 0)
            .await
            .unwrap();

        let res = test_tree.message_key_generation(&cs, 4, key_type, 0).await;
        assert_matches!(res, Err(MlsError::KeyMissing(0)));

        // The failed request must not have lost the leaf's ratchet.
        test_tree
            .message_key_generation(&cs, 4, key_type, 1)
            .await
            .unwrap();
    }

    #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    struct Ratchet {
        application_keys: Vec<Vec<u8>>,
        handshake_keys: Vec<Vec<u8>>,
    }

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        ratchets: Vec<Ratchet>,
    }

    /// Collects 20 + 20 encoded keys for each of the first 16 leaves (even node
    /// indices 0..30) of `secret_tree`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn get_ratchet_data(
        secret_tree: &mut SecretTree<NodeIndex>,
        cipher_suite: CipherSuite,
    ) -> Vec<Ratchet> {
        let provider = test_cipher_suite_provider(cipher_suite);
        let mut ratchet_data = Vec::new();

        for index in 0..16 {
            let mut ratchets = secret_tree
                .take_leaf_ratchet(&provider, &(index * 2))
                .await
                .unwrap();

            let mut application_keys = Vec::new();

            // NOTE(review): `application_keys` is filled from the *handshake*
            // ratchet, just like `handshake_keys` below, so the application ratchet
            // is never exercised here. The stored `secret_tree` vectors were
            // presumably generated with this code, so switching to
            // `ratchets.application` means regenerating them as well. TODO confirm
            // and fix both together.
            for _ in 0..20 {
                let key = ratchets
                    .handshake
                    .next_message_key(&provider)
                    .await
                    .unwrap()
                    .mls_encode_to_vec()
                    .unwrap();

                application_keys.push(key);
            }

            let mut handshake_keys = Vec::new();

            for _ in 0..20 {
                let key = ratchets
                    .handshake
                    .next_message_key(&provider)
                    .await
                    .unwrap()
                    .mls_encode_to_vec()
                    .unwrap();

                handshake_keys.push(key);
            }

            ratchet_data.push(Ratchet {
                application_keys,
                handshake_keys,
            });
        }

        ratchet_data
    }

    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<TestCase> {
        CipherSuite::all()
            .map(|cipher_suite| {
                let provider = test_cipher_suite_provider(cipher_suite);
                let encryption_secret = random_bytes(provider.kdf_extract_size());

                let mut secret_tree =
                    SecretTree::new(16, Zeroizing::new(encryption_secret.clone()));

                TestCase {
                    cipher_suite: cipher_suite.into(),
                    encryption_secret,
                    ratchets: get_ratchet_data(&mut secret_tree, cipher_suite),
                }
            })
            .collect()
    }

    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode");
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_tree_test_vectors() {
        let test_cases: Vec<TestCase> = load_test_case_json!(secret_tree, generate_test_vector());

        for case in test_cases {
            let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
                continue;
            };

            let mut secret_tree = SecretTree::new(16, Zeroizing::new(case.encryption_secret));
            let ratchet_data = get_ratchet_data(&mut secret_tree, cs_provider.cipher_suite()).await;

            assert_eq!(ratchet_data, case.ratchets);
        }
    }

    #[cfg(feature = "export_key_generation")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn peek_next_key_generation() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let provider = test_cipher_suite_provider(cipher_suite);

            let app_ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Application,
            )
            .await
            .unwrap();

            let handshake_ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Handshake,
            )
            .await
            .unwrap();

            // `peek_next_key_generation` is an async fn in async builds, so it has to
            // be awaited; maybe_async strips the `.await` in sync builds.
            for mut ratchet in [app_ratchet, handshake_ratchet] {
                assert_eq!(ratchet.peek_next_key_generation().await, 0);
                assert_eq!(ratchet.peek_next_key_generation().await, 0);

                let key_zero = ratchet.next_message_key(&provider).await.unwrap();
                assert_eq!(key_zero.generation, 0);

                assert_eq!(ratchet.peek_next_key_generation().await, 1);
                let _key_one = ratchet.next_message_key(&provider).await.unwrap();

                let key_two = ratchet.next_message_key(&provider).await.unwrap();
                assert_eq!(key_two.generation, 2);

                assert_eq!(ratchet.peek_next_key_generation().await, 3);
            }
        }
    }
}
1017
#[cfg(all(test, feature = "rfc_compliant", feature = "std"))]
mod interop_tests {
    #[cfg(not(mls_build_async))]
    use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
    use zeroize::Zeroizing;

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType},
    };

    use super::SecretTree;

    /// Replays the `secret_tree_interop` vectors.
    ///
    /// For every supported cipher suite it checks the sender data. It then checks
    /// the application and handshake key/nonce of every listed leaf generation.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn interop_test_vector() {
        for case in load_interop_test_cases() {
            let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else {
                continue;
            };

            case.sender_data.verify(&cs).await;

            let leaf_count = case.leaves.len() as u32;
            let mut tree = SecretTree::new(leaf_count, Zeroizing::new(case.encryption_secret));

            for (leaf, expected_generations) in case.leaves.iter().enumerate() {
                // Leaf `i` is addressed by node index `2 * i`.
                let node_index = (leaf as u32) * 2;

                for expected in expected_generations {
                    let application = tree
                        .message_key_generation(
                            &cs,
                            node_index,
                            KeyType::Application,
                            expected.generation,
                        )
                        .await
                        .unwrap();

                    assert_eq!(application.key.to_vec(), expected.application_key);
                    assert_eq!(application.nonce.to_vec(), expected.application_nonce);

                    let handshake = tree
                        .message_key_generation(
                            &cs,
                            node_index,
                            KeyType::Handshake,
                            expected.generation,
                        )
                        .await
                        .unwrap();

                    assert_eq!(handshake.key.to_vec(), expected.handshake_key);
                    assert_eq!(handshake.nonce.to_vec(), expected.handshake_nonce);
                }
            }
        }
    }

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropTestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        sender_data: InteropSenderData,
        leaves: Vec<Vec<InteropLeaf>>,
    }

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropLeaf {
        generation: u32,
        #[serde(with = "hex::serde")]
        application_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        application_nonce: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_nonce: Vec<u8>,
    }

    fn load_interop_test_cases() -> Vec<InteropTestCase> {
        load_test_case_json!(secret_tree_interop, generate_test_vector())
    }

    /// Generates one case per supported cipher suite and tree size. Each case
    /// records the keys for generations 0 and 15 of every leaf.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        let mut test_cases = vec![];

        for cs in CipherSuite::all() {
            let Some(cs) = try_test_cipher_suite_provider(*cs) else {
                continue;
            };

            let generations = [0, 15];
            let tree_sizes = [1, 8, 32];

            for n_leaves in tree_sizes {
                let encryption_secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();

                let mut tree = SecretTree::new(n_leaves, Zeroizing::new(encryption_secret.clone()));

                let mut leaves = Vec::new();

                for leaf in 0..n_leaves {
                    let index = leaf * 2u32;
                    let mut per_generation = Vec::new();

                    for generation in generations {
                        let handshake_key = tree
                            .message_key_generation(&cs, index, KeyType::Handshake, generation)
                            .unwrap();

                        let app_key = tree
                            .message_key_generation(&cs, index, KeyType::Application, generation)
                            .unwrap();

                        per_generation.push(InteropLeaf {
                            generation,
                            application_key: app_key.key.to_vec(),
                            application_nonce: app_key.nonce.to_vec(),
                            handshake_key: handshake_key.key.to_vec(),
                            handshake_nonce: handshake_key.nonce.to_vec(),
                        });
                    }

                    leaves.push(per_generation);
                }

                test_cases.push(InteropTestCase {
                    cipher_suite: *cs.cipher_suite(),
                    encryption_secret,
                    sender_data: InteropSenderData::new(&cs),
                    leaves,
                });
            }
        }

        test_cases
    }

    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        panic!("Tests cannot be generated in async mode");
    }
}