Skip to main content

mls_rs/group/
secret_tree.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use alloc::vec::Vec;
6use core::{
7    fmt::{self, Debug},
8    ops::{Deref, DerefMut},
9};
10
11use zeroize::Zeroizing;
12
13use crate::{client::MlsError, map::LargeMap, tree_kem::math::TreeIndex, CipherSuiteProvider};
14
15use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
16use mls_rs_core::error::IntoAnyError;
17
18use super::key_schedule::kdf_expand_with_label;
19
/// How many generations past its current position a [`SecretKeyRatchet`] may be advanced
/// when a specific generation is requested; requests beyond this bound are rejected with
/// `MlsError::InvalidFutureGeneration`.
pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1 << 10; // 1024
21
22#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
23#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
24#[repr(u8)]
25enum SecretTreeNode {
26    Secret(TreeSecret) = 0u8,
27    Ratchet(SecretRatchets) = 1u8,
28}
29
30impl SecretTreeNode {
31    fn into_secret(self) -> Option<TreeSecret> {
32        if let SecretTreeNode::Secret(secret) = self {
33            Some(secret)
34        } else {
35            None
36        }
37    }
38}
39
40#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
41#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
42struct TreeSecret(
43    #[mls_codec(with = "mls_rs_codec::byte_vec")]
44    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
45    Zeroizing<Vec<u8>>,
46);
47
48impl Debug for TreeSecret {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("TreeSecret").finish()
51    }
52}
53
54impl Deref for TreeSecret {
55    type Target = Vec<u8>;
56
57    fn deref(&self) -> &Self::Target {
58        &self.0
59    }
60}
61
62impl DerefMut for TreeSecret {
63    fn deref_mut(&mut self) -> &mut Self::Target {
64        &mut self.0
65    }
66}
67
68impl AsRef<[u8]> for TreeSecret {
69    fn as_ref(&self) -> &[u8] {
70        &self.0
71    }
72}
73
74impl From<Vec<u8>> for TreeSecret {
75    fn from(vec: Vec<u8>) -> Self {
76        TreeSecret(Zeroizing::new(vec))
77    }
78}
79
80impl From<Zeroizing<Vec<u8>>> for TreeSecret {
81    fn from(vec: Zeroizing<Vec<u8>>) -> Self {
82        TreeSecret(vec)
83    }
84}
85
86#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
87#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
88struct TreeSecretsVec<T: TreeIndex> {
89    inner: LargeMap<T, SecretTreeNode>,
90}
91
92impl<T: TreeIndex> TreeSecretsVec<T> {
93    fn set_node(&mut self, index: T, value: SecretTreeNode) {
94        self.inner.insert(index, value);
95    }
96
97    fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
98        self.inner.remove(index)
99    }
100}
101
102#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)]
103#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
104pub struct SecretTree<T: TreeIndex> {
105    known_secrets: TreeSecretsVec<T>,
106    leaf_count: T,
107}
108
109impl<T: TreeIndex> SecretTree<T> {
110    pub(crate) fn empty() -> SecretTree<T> {
111        SecretTree {
112            known_secrets: Default::default(),
113            leaf_count: T::zero(),
114        }
115    }
116}
117
118#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
119#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
120pub struct SecretRatchets {
121    pub application: SecretKeyRatchet,
122    pub handshake: SecretKeyRatchet,
123}
124
125impl SecretRatchets {
126    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
127    pub async fn message_key_generation<P: CipherSuiteProvider>(
128        &mut self,
129        cipher_suite_provider: &P,
130        generation: u32,
131        key_type: KeyType,
132    ) -> Result<MessageKeyData, MlsError> {
133        match key_type {
134            KeyType::Handshake => {
135                self.handshake
136                    .get_message_key(cipher_suite_provider, generation)
137                    .await
138            }
139            KeyType::Application => {
140                self.application
141                    .get_message_key(cipher_suite_provider, generation)
142                    .await
143            }
144        }
145    }
146
147    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
148    pub async fn next_message_key<P: CipherSuiteProvider>(
149        &mut self,
150        cipher_suite: &P,
151        key_type: KeyType,
152    ) -> Result<MessageKeyData, MlsError> {
153        match key_type {
154            KeyType::Handshake => self.handshake.next_message_key(cipher_suite).await,
155            KeyType::Application => self.application.next_message_key(cipher_suite).await,
156        }
157    }
158
159    /// Peeks at the next key generation for `key_type`, but does not increment the
160    /// generation nor derive keys.
161    #[cfg(feature = "export_key_generation")]
162    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync, allow(dead_code))]
163    async fn peek_next_key_generation(&self, key_type: KeyType) -> u32 {
164        match key_type {
165            KeyType::Handshake => self.handshake.peek_next_key_generation().await,
166            KeyType::Application => self.application.peek_next_key_generation().await,
167        }
168    }
169}
170
171impl<T: TreeIndex> SecretTree<T> {
172    pub fn new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T> {
173        let mut known_secrets = TreeSecretsVec::default();
174
175        let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret));
176        known_secrets.set_node(leaf_count.root(), root_secret);
177
178        Self {
179            known_secrets,
180            leaf_count,
181        }
182    }
183
184    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
185    async fn consume_node<P: CipherSuiteProvider>(
186        &mut self,
187        cipher_suite_provider: &P,
188        index: &T,
189    ) -> Result<(), MlsError> {
190        let node = self.known_secrets.take_node(index);
191
192        if let Some(secret) = node.and_then(|n| n.into_secret()) {
193            let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?;
194            let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?;
195
196            let left_secret =
197                kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None)
198                    .await?;
199
200            let right_secret =
201                kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None)
202                    .await?;
203
204            self.known_secrets
205                .set_node(left_index, SecretTreeNode::Secret(left_secret.into()));
206
207            self.known_secrets
208                .set_node(right_index, SecretTreeNode::Secret(right_secret.into()));
209        }
210
211        Ok(())
212    }
213
214    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
215    async fn take_leaf_ratchet<P: CipherSuiteProvider>(
216        &mut self,
217        cipher_suite: &P,
218        leaf_index: &T,
219    ) -> Result<SecretRatchets, MlsError> {
220        let node_index = leaf_index;
221
222        let node = match self.known_secrets.take_node(node_index) {
223            Some(node) => node,
224            None => {
225                // Start at the root node and work your way down consuming any intermediates needed
226                for i in node_index.direct_copath(&self.leaf_count).into_iter().rev() {
227                    self.consume_node(cipher_suite, &i.path).await?;
228                }
229
230                self.known_secrets
231                    .take_node(node_index)
232                    .ok_or(MlsError::InvalidLeafConsumption)?
233            }
234        };
235
236        Ok(match node {
237            SecretTreeNode::Ratchet(ratchet) => ratchet,
238            SecretTreeNode::Secret(secret) => SecretRatchets {
239                application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application)
240                    .await?,
241                handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?,
242            },
243        })
244    }
245
246    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
247    pub async fn next_message_key<P: CipherSuiteProvider>(
248        &mut self,
249        cipher_suite: &P,
250        leaf_index: T,
251        key_type: KeyType,
252    ) -> Result<MessageKeyData, MlsError> {
253        let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
254        let res = ratchet.next_message_key(cipher_suite, key_type).await?;
255
256        self.known_secrets
257            .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
258
259        Ok(res)
260    }
261
262    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
263    pub async fn message_key_generation<P: CipherSuiteProvider>(
264        &mut self,
265        cipher_suite: &P,
266        leaf_index: T,
267        key_type: KeyType,
268        generation: u32,
269    ) -> Result<MessageKeyData, MlsError> {
270        let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
271
272        let res = ratchet
273            .message_key_generation(cipher_suite, generation, key_type)
274            .await;
275
276        self.known_secrets
277            .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
278
279        res
280    }
281
282    /// Peeks at the next key generation, but does not increment the generation nor
283    /// derive keys.
284    ///
285    /// Takes &mut self since take_leaf_ratchet constructs and stores nodes in the
286    /// SecretTree the first time they are requested.
287    ///
288    /// Called by [`Group::peek_next_key_generation`], which is used by clients to
289    /// authenticate the generation to defend against in-group forgery attacks described
290    /// in https://eprint.iacr.org/2025/554
291    #[cfg(feature = "export_key_generation")]
292    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync, allow(dead_code))]
293    pub async fn peek_next_key_generation<P: CipherSuiteProvider>(
294        &mut self,
295        cipher_suite: &P,
296        leaf_index: T,
297        key_type: KeyType,
298    ) -> Result<u32, MlsError> {
299        let ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
300        let res = ratchet.peek_next_key_generation(key_type).await;
301
302        self.known_secrets
303            .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
304
305        Ok(res)
306    }
307}
308
/// Selects which of a leaf's two ratchets a message key is taken from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KeyType {
    /// Use the handshake ratchet.
    Handshake,
    /// Use the application ratchet.
    Application,
}
314
315#[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
316#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
317/// AEAD key derived by the MLS secret tree.
318pub struct MessageKeyData {
319    #[mls_codec(with = "mls_rs_codec::byte_vec")]
320    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
321    pub(crate) nonce: Zeroizing<Vec<u8>>,
322    #[mls_codec(with = "mls_rs_codec::byte_vec")]
323    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
324    pub(crate) key: Zeroizing<Vec<u8>>,
325    pub(crate) generation: u32,
326}
327
328impl Debug for MessageKeyData {
329    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
330        f.debug_struct("MessageKeyData")
331            .field("generation", &self.generation)
332            .finish()
333    }
334}
335
336impl MessageKeyData {
337    /// AEAD nonce.
338    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
339    pub fn nonce(&self) -> &[u8] {
340        &self.nonce
341    }
342
343    /// AEAD key.
344    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
345    pub fn key(&self) -> &[u8] {
346        &self.key
347    }
348
349    /// Generation of this key within the key schedule.
350    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
351    pub fn generation(&self) -> u32 {
352        self.generation
353    }
354}
355
356#[derive(Debug, Clone, PartialEq)]
357#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
358pub struct SecretKeyRatchet {
359    secret: TreeSecret,
360    generation: u32,
361    #[cfg(feature = "out_of_order")]
362    history: LargeMap<u32, MessageKeyData>,
363}
364
365impl MlsSize for SecretKeyRatchet {
366    fn mls_encoded_len(&self) -> usize {
367        let len = mls_rs_codec::byte_vec::mls_encoded_len(&self.secret)
368            + self.generation.mls_encoded_len();
369
370        #[cfg(feature = "out_of_order")]
371        return len + mls_rs_codec::iter::mls_encoded_len(self.history.values());
372        #[cfg(not(feature = "out_of_order"))]
373        return len;
374    }
375}
376
377#[cfg(feature = "out_of_order")]
378impl MlsEncode for SecretKeyRatchet {
379    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
380        mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
381        self.generation.mls_encode(writer)?;
382        mls_rs_codec::iter::mls_encode(self.history.values(), writer)
383    }
384}
385
386#[cfg(not(feature = "out_of_order"))]
387impl MlsEncode for SecretKeyRatchet {
388    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
389        mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
390        self.generation.mls_encode(writer)
391    }
392}
393
394impl MlsDecode for SecretKeyRatchet {
395    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
396        Ok(Self {
397            secret: mls_rs_codec::byte_vec::mls_decode(reader)?,
398            generation: u32::mls_decode(reader)?,
399            #[cfg(feature = "out_of_order")]
400            history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
401                let mut items = LargeMap::default();
402
403                while !data.is_empty() {
404                    let item = MessageKeyData::mls_decode(data)?;
405                    items.insert(item.generation, item);
406                }
407
408                Ok(items)
409            })?,
410        })
411    }
412}
413
414impl SecretKeyRatchet {
415    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
416    async fn new<P: CipherSuiteProvider>(
417        cipher_suite_provider: &P,
418        secret: &[u8],
419        key_type: KeyType,
420    ) -> Result<Self, MlsError> {
421        let label = match key_type {
422            KeyType::Handshake => b"handshake".as_slice(),
423            KeyType::Application => b"application".as_slice(),
424        };
425
426        let secret = kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None)
427            .await
428            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
429
430        Ok(Self {
431            secret: TreeSecret::from(secret),
432            generation: 0,
433            #[cfg(feature = "out_of_order")]
434            history: Default::default(),
435        })
436    }
437
438    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
439    async fn get_message_key<P: CipherSuiteProvider>(
440        &mut self,
441        cipher_suite_provider: &P,
442        generation: u32,
443    ) -> Result<MessageKeyData, MlsError> {
444        #[cfg(feature = "out_of_order")]
445        if generation < self.generation {
446            return self
447                .history
448                .remove_entry(&generation)
449                .map(|(_, mk)| mk)
450                .ok_or(MlsError::KeyMissing(generation));
451        }
452
453        #[cfg(not(feature = "out_of_order"))]
454        if generation < self.generation {
455            return Err(MlsError::KeyMissing(generation));
456        }
457
458        let max_generation_allowed = self.generation + MAX_RATCHET_BACK_HISTORY;
459
460        if generation > max_generation_allowed {
461            return Err(MlsError::InvalidFutureGeneration(generation));
462        }
463
464        #[cfg(not(feature = "out_of_order"))]
465        while self.generation < generation {
466            self.next_message_key(cipher_suite_provider)?;
467        }
468
469        #[cfg(feature = "out_of_order")]
470        while self.generation < generation {
471            let key_data = self.next_message_key(cipher_suite_provider).await?;
472            self.history.insert(key_data.generation, key_data);
473        }
474
475        self.next_message_key(cipher_suite_provider).await
476    }
477
478    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
479    async fn next_message_key<P: CipherSuiteProvider>(
480        &mut self,
481        cipher_suite_provider: &P,
482    ) -> Result<MessageKeyData, MlsError> {
483        let generation = self.generation;
484
485        let key = MessageKeyData {
486            nonce: self
487                .derive_secret(
488                    cipher_suite_provider,
489                    b"nonce",
490                    cipher_suite_provider.aead_nonce_size(),
491                )
492                .await?,
493            key: self
494                .derive_secret(
495                    cipher_suite_provider,
496                    b"key",
497                    cipher_suite_provider.aead_key_size(),
498                )
499                .await?,
500            generation,
501        };
502
503        self.secret = self
504            .derive_secret(
505                cipher_suite_provider,
506                b"secret",
507                cipher_suite_provider.kdf_extract_size(),
508            )
509            .await?
510            .into();
511
512        self.generation = generation + 1;
513
514        Ok(key)
515    }
516
517    /// Peeks at the next key generation, but does not increment the generation nor
518    /// derive keys.
519    #[cfg(feature = "export_key_generation")]
520    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync, allow(dead_code))]
521    async fn peek_next_key_generation(&self) -> u32 {
522        self.generation
523    }
524
525    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
526    async fn derive_secret<P: CipherSuiteProvider>(
527        &self,
528        cipher_suite_provider: &P,
529        label: &[u8],
530        len: usize,
531    ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
532        kdf_expand_with_label(
533            cipher_suite_provider,
534            self.secret.as_ref(),
535            label,
536            &self.generation.to_be_bytes(),
537            Some(len),
538        )
539        .await
540        .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
541    }
542}
543
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::{string::String, vec::Vec};
    use mls_rs_core::crypto::CipherSuiteProvider;
    use zeroize::Zeroizing;

    use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex};

    use super::{KeyType, SecretKeyRatchet, SecretTree};

    /// Builds a secret tree over `leaf_count` leaves whose root holds `secret`.
    pub(crate) fn get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T> {
        let encryption_secret = Zeroizing::new(secret);
        SecretTree::new(leaf_count, encryption_secret)
    }

    impl SecretTree<u32> {
        /// Returns a copy of the root secret, leaving the tree itself untouched.
        ///
        /// Panics if the root node is absent or no longer holds a raw secret.
        pub(crate) fn get_root_secret(&self) -> Vec<u8> {
            let root = self.leaf_count.root();

            // Operate on a clone so the tree keeps its root secret.
            let mut snapshot = self.known_secrets.clone();

            snapshot
                .take_node(&root)
                .and_then(|node| node.into_secret())
                .unwrap()
                .to_vec()
        }
    }

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct RatchetInteropTestCase {
        #[serde(with = "hex::serde")]
        secret: Vec<u8>,
        label: String,
        generation: u32,
        length: usize,
        #[serde(with = "hex::serde")]
        out: Vec<u8>,
    }

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        derive_tree_secret: RatchetInteropTestCase,
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        let test_cases: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());

        for test_case in test_cases {
            // Cipher suites not supported by the test provider are skipped.
            match try_test_cipher_suite_provider(test_case.cipher_suite) {
                Some(cs) => test_case.derive_tree_secret.verify(&cs).await,
                None => continue,
            }
        }
    }

    impl RatchetInteropTestCase {
        /// Checks that `DeriveTreeSecret(secret, label, generation, length)` matches `out`.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
            let mut ratchet = SecretKeyRatchet::new(cs, &self.secret, KeyType::Application)
                .await
                .unwrap();

            // Overwrite the derived state with the raw test-vector inputs.
            ratchet.secret = self.secret.clone().into();
            ratchet.generation = self.generation;

            let label = self.label.as_bytes();
            let computed = ratchet.derive_secret(cs, label, self.length).await.unwrap();

            assert_eq!(computed.as_slice(), self.out.as_slice());
        }
    }
}
618
619#[cfg(test)]
620mod tests {
621    use alloc::vec;
622
623    use crate::{
624        cipher_suite::CipherSuite,
625        client::test_utils::TEST_CIPHER_SUITE,
626        crypto::test_utils::{
627            test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
628        },
629        tree_kem::node::NodeIndex,
630    };
631
632    #[cfg(not(mls_build_async))]
633    use crate::group::test_utils::random_bytes;
634
635    use super::{test_utils::get_test_tree, *};
636
637    use assert_matches::assert_matches;
638
639    #[cfg(target_arch = "wasm32")]
640    use wasm_bindgen_test::wasm_bindgen_test as test;
641
642    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
643    async fn test_secret_tree() {
644        test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await;
645        test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await;
646    }
647
648    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
649    async fn test_secret_tree_custom<T: TreeIndex>(
650        leaf_count: T,
651        leaves_to_check: Vec<T>,
652        all_deleted: bool,
653    ) {
654        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
655            let cs_provider = test_cipher_suite_provider(cipher_suite);
656
657            let test_secret = vec![0u8; cs_provider.kdf_extract_size()];
658            let mut test_tree = get_test_tree(test_secret, leaf_count.clone());
659
660            let mut secrets = Vec::<SecretRatchets>::new();
661
662            for i in &leaves_to_check {
663                let secret = test_tree
664                    .take_leaf_ratchet(&test_cipher_suite_provider(cipher_suite), i)
665                    .await
666                    .unwrap();
667
668                secrets.push(secret);
669            }
670
671            // Verify the tree is now completely empty
672            assert!(!all_deleted || test_tree.known_secrets.inner.is_empty());
673
674            // Verify that all the secrets are unique
675            let count = secrets.len();
676            secrets.dedup();
677            assert_eq!(count, secrets.len());
678        }
679    }
680
681    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
682    async fn test_secret_key_ratchet() {
683        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
684            let provider = test_cipher_suite_provider(cipher_suite);
685
686            let mut app_ratchet = SecretKeyRatchet::new(
687                &provider,
688                &vec![0u8; provider.kdf_extract_size()],
689                KeyType::Application,
690            )
691            .await
692            .unwrap();
693
694            let mut handshake_ratchet = SecretKeyRatchet::new(
695                &provider,
696                &vec![0u8; provider.kdf_extract_size()],
697                KeyType::Handshake,
698            )
699            .await
700            .unwrap();
701
702            let app_key_one = app_ratchet.next_message_key(&provider).await.unwrap();
703            let app_key_two = app_ratchet.next_message_key(&provider).await.unwrap();
704            let app_keys = vec![app_key_one, app_key_two];
705
706            let handshake_key_one = handshake_ratchet.next_message_key(&provider).await.unwrap();
707            let handshake_key_two = handshake_ratchet.next_message_key(&provider).await.unwrap();
708            let handshake_keys = vec![handshake_key_one, handshake_key_two];
709
710            // Verify that the keys have different outcomes due to their different labels
711            assert_ne!(app_keys, handshake_keys);
712
713            // Verify that the keys at each generation are different
714            assert_ne!(handshake_keys[0], handshake_keys[1]);
715        }
716    }
717
718    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
719    async fn test_get_key() {
720        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
721            let provider = test_cipher_suite_provider(cipher_suite);
722
723            let mut ratchet = SecretKeyRatchet::new(
724                &test_cipher_suite_provider(cipher_suite),
725                &vec![0u8; provider.kdf_extract_size()],
726                KeyType::Application,
727            )
728            .await
729            .unwrap();
730
731            let mut ratchet_clone = ratchet.clone();
732
733            // This will generate keys 0 and 1 in ratchet_clone
734            let _ = ratchet_clone.next_message_key(&provider).await.unwrap();
735            let clone_2 = ratchet_clone.next_message_key(&provider).await.unwrap();
736
737            // Going back in time should result in an error
738            let res = ratchet_clone.get_message_key(&provider, 0).await;
739            assert!(res.is_err());
740
741            // Calling get key should be the same as calling next until hitting the desired generation
742            let second_key = ratchet
743                .get_message_key(&provider, ratchet_clone.generation - 1)
744                .await
745                .unwrap();
746
747            assert_eq!(clone_2, second_key)
748        }
749    }
750
751    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
752    async fn test_secret_ratchet() {
753        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
754            let provider = test_cipher_suite_provider(cipher_suite);
755
756            let mut ratchet = SecretKeyRatchet::new(
757                &provider,
758                &vec![0u8; provider.kdf_extract_size()],
759                KeyType::Application,
760            )
761            .await
762            .unwrap();
763
764            let original_secret = ratchet.secret.clone();
765            let _ = ratchet.next_message_key(&provider).await.unwrap();
766            let new_secret = ratchet.secret;
767            assert_ne!(original_secret, new_secret)
768        }
769    }
770
771    #[cfg(feature = "out_of_order")]
772    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
773    async fn test_out_of_order_keys() {
774        let cipher_suite = TEST_CIPHER_SUITE;
775        let provider = test_cipher_suite_provider(cipher_suite);
776
777        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
778            .await
779            .unwrap();
780        let mut ratchet_clone = ratchet.clone();
781
782        // Ask for all the keys in order from the original ratchet
783        let mut ordered_keys = Vec::<MessageKeyData>::new();
784
785        for i in 0..=MAX_RATCHET_BACK_HISTORY {
786            ordered_keys.push(ratchet.get_message_key(&provider, i).await.unwrap());
787        }
788
789        // Ask for a key at index MAX_RATCHET_BACK_HISTORY in the clone
790        let last_key = ratchet_clone
791            .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY)
792            .await
793            .unwrap();
794
795        assert_eq!(last_key, ordered_keys[ordered_keys.len() - 1]);
796
797        // Get all the other keys
798        let mut back_history_keys = Vec::<MessageKeyData>::new();
799
800        for i in 0..MAX_RATCHET_BACK_HISTORY - 1 {
801            back_history_keys.push(ratchet_clone.get_message_key(&provider, i).await.unwrap());
802        }
803
804        assert_eq!(
805            back_history_keys,
806            ordered_keys[..(MAX_RATCHET_BACK_HISTORY as usize) - 1]
807        );
808    }
809
810    #[cfg(not(feature = "out_of_order"))]
811    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
812    async fn out_of_order_keys_should_throw_error() {
813        let cipher_suite = TEST_CIPHER_SUITE;
814        let provider = test_cipher_suite_provider(cipher_suite);
815
816        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
817            .await
818            .unwrap();
819
820        ratchet.get_message_key(&provider, 10).await.unwrap();
821        let res = ratchet.get_message_key(&provider, 9).await;
822        assert_matches!(res, Err(MlsError::KeyMissing(9)))
823    }
824
825    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
826    async fn test_too_out_of_order() {
827        let cipher_suite = TEST_CIPHER_SUITE;
828        let provider = test_cipher_suite_provider(cipher_suite);
829
830        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
831            .await
832            .unwrap();
833
834        let res = ratchet
835            .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY + 1)
836            .await;
837
838        let invalid_generation = MAX_RATCHET_BACK_HISTORY + 1;
839
840        assert_matches!(
841            res,
842            Err(MlsError::InvalidFutureGeneration(invalid))
843            if invalid == invalid_generation
844        )
845    }
846
847    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
848    async fn double_hit_leaves_epoch_intact() {
849        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
850        let test_secret = vec![0u8; cs.kdf_extract_size()];
851        let mut test_tree = get_test_tree(test_secret, 4u32);
852        let key_type = KeyType::Application;
853
854        // We receive a ciphertext from leaf 2 (node 4)
855        test_tree
856            .message_key_generation(&cs, 4, key_type, 0)
857            .await
858            .unwrap();
859
860        // Due to a double hit we receive that ciphertext again
861        let res = test_tree.message_key_generation(&cs, 4, key_type, 0).await;
862        assert_matches!(res, Err(MlsError::KeyMissing(0)));
863
864        // We receive another ciphertext from leaf 2
865        test_tree
866            .message_key_generation(&cs, 4, key_type, 1)
867            .await
868            .unwrap();
869    }
870
871    #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
872    struct Ratchet {
873        application_keys: Vec<Vec<u8>>,
874        handshake_keys: Vec<Vec<u8>>,
875    }
876
    /// One entry of the stored `secret_tree` test vector.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestCase {
        /// Numeric cipher suite identifier.
        cipher_suite: u16,
        /// Encryption secret used to build a 16-leaf `SecretTree`; hex-encoded in JSON.
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        /// Expected key lists, one `Ratchet` per leaf, in leaf order.
        ratchets: Vec<Ratchet>,
    }
884
885    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
886    async fn get_ratchet_data(
887        secret_tree: &mut SecretTree<NodeIndex>,
888        cipher_suite: CipherSuite,
889    ) -> Vec<Ratchet> {
890        let provider = test_cipher_suite_provider(cipher_suite);
891        let mut ratchet_data = Vec::new();
892
893        for index in 0..16 {
894            let mut ratchets = secret_tree
895                .take_leaf_ratchet(&provider, &(index * 2))
896                .await
897                .unwrap();
898
899            let mut application_keys = Vec::new();
900
901            for _ in 0..20 {
902                let key = ratchets
903                    .handshake
904                    .next_message_key(&provider)
905                    .await
906                    .unwrap()
907                    .mls_encode_to_vec()
908                    .unwrap();
909
910                application_keys.push(key);
911            }
912
913            let mut handshake_keys = Vec::new();
914
915            for _ in 0..20 {
916                let key = ratchets
917                    .handshake
918                    .next_message_key(&provider)
919                    .await
920                    .unwrap()
921                    .mls_encode_to_vec()
922                    .unwrap();
923
924                handshake_keys.push(key);
925            }
926
927            ratchet_data.push(Ratchet {
928                application_keys,
929                handshake_keys,
930            });
931        }
932
933        ratchet_data
934    }
935
936    #[cfg(not(mls_build_async))]
937    #[cfg_attr(coverage_nightly, coverage(off))]
938    fn generate_test_vector() -> Vec<TestCase> {
939        CipherSuite::all()
940            .map(|cipher_suite| {
941                let provider = test_cipher_suite_provider(cipher_suite);
942                let encryption_secret = random_bytes(provider.kdf_extract_size());
943
944                let mut secret_tree =
945                    SecretTree::new(16, Zeroizing::new(encryption_secret.clone()));
946
947                TestCase {
948                    cipher_suite: cipher_suite.into(),
949                    encryption_secret,
950                    ratchets: get_ratchet_data(&mut secret_tree, cipher_suite),
951                }
952            })
953            .collect()
954    }
955
    /// Async builds cannot regenerate the `secret_tree` vector; this stub always panics.
    // NOTE(review): presumably only reached through `load_test_case_json!` when the stored
    // vector is missing, so async test runs depend on the JSON fixture existing — confirm.
    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode");
    }
960
961    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
962    async fn test_secret_tree_test_vectors() {
963        let test_cases: Vec<TestCase> = load_test_case_json!(secret_tree, generate_test_vector());
964
965        for case in test_cases {
966            let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
967                continue;
968            };
969
970            let mut secret_tree = SecretTree::new(16, Zeroizing::new(case.encryption_secret));
971            let ratchet_data = get_ratchet_data(&mut secret_tree, cs_provider.cipher_suite()).await;
972
973            assert_eq!(ratchet_data, case.ratchets);
974        }
975    }
976
977    #[cfg(feature = "export_key_generation")]
978    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
979    async fn peek_next_key_generation() {
980        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
981            let provider = test_cipher_suite_provider(cipher_suite);
982
983            let app_ratchet = SecretKeyRatchet::new(
984                &provider,
985                &vec![0u8; provider.kdf_extract_size()],
986                KeyType::Application,
987            )
988            .await
989            .unwrap();
990
991            let handshake_ratchet = SecretKeyRatchet::new(
992                &provider,
993                &vec![0u8; provider.kdf_extract_size()],
994                KeyType::Handshake,
995            )
996            .await
997            .unwrap();
998
999            for mut ratchet in [app_ratchet, handshake_ratchet] {
1000                assert_eq!(ratchet.peek_next_key_generation(), 0);
1001                assert_eq!(ratchet.peek_next_key_generation(), 0);
1002
1003                let key_zero = ratchet.next_message_key(&provider).await.unwrap();
1004                assert_eq!(key_zero.generation, 0);
1005
1006                assert_eq!(ratchet.peek_next_key_generation(), 1);
1007                let _key_one = ratchet.next_message_key(&provider).await.unwrap();
1008
1009                let key_two = ratchet.next_message_key(&provider).await.unwrap();
1010                assert_eq!(key_two.generation, 2);
1011
1012                assert_eq!(ratchet.peek_next_key_generation(), 3);
1013            }
1014        }
1015    }
1016}
1017
// Cross-implementation checks against the MLS working group's secret-tree test vectors.
#[cfg(all(test, feature = "rfc_compliant", feature = "std"))]
mod interop_tests {
    #[cfg(not(mls_build_async))]
    use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
    use zeroize::Zeroizing;

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType},
    };

    use super::SecretTree;

    // For every case with an available provider:
    // - verify the sender-data part;
    // - derive application and handshake keys and nonces for each listed
    //   (leaf, generation) pair and compare them with the recorded values.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn interop_test_vector() {
        // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/secret-tree.json
        for case in load_interop_test_cases() {
            let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else {
                continue;
            };

            case.sender_data.verify(&cs).await;

            let leaf_count = case.leaves.len() as u32;
            let mut tree = SecretTree::new(leaf_count, Zeroizing::new(case.encryption_secret));

            for (leaf_position, expectations) in case.leaves.iter().enumerate() {
                // Leaves occupy the even node indices.
                let node_index = (leaf_position as u32) * 2;

                for leaf in expectations {
                    let checks = [
                        (
                            KeyType::Application,
                            &leaf.application_key,
                            &leaf.application_nonce,
                        ),
                        (
                            KeyType::Handshake,
                            &leaf.handshake_key,
                            &leaf.handshake_nonce,
                        ),
                    ];

                    for (key_type, expected_key, expected_nonce) in checks {
                        let derived = tree
                            .message_key_generation(&cs, node_index, key_type, leaf.generation)
                            .await
                            .unwrap();

                        assert_eq!(derived.key.to_vec(), *expected_key);
                        assert_eq!(derived.nonce.to_vec(), *expected_nonce);
                    }
                }
            }
        }
    }

    /// One interop case: a cipher suite, the tree's encryption secret, sender-data
    /// material, and per-leaf expected keys.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropTestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        sender_data: InteropSenderData,
        /// Outer index is the leaf index; inner entries are distinct generations.
        leaves: Vec<Vec<InteropLeaf>>,
    }

    /// Expected key/nonce pairs for one leaf at one generation (hex-encoded in JSON).
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropLeaf {
        generation: u32,
        #[serde(with = "hex::serde")]
        application_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        application_nonce: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_nonce: Vec<u8>,
    }

    fn load_interop_test_cases() -> Vec<InteropTestCase> {
        load_test_case_json!(secret_tree_interop, generate_test_vector())
    }

    /// Builds interop cases for every cipher suite with an available provider.
    ///
    /// Covers tree sizes 1, 8 and 32, each with a fresh random encryption secret.
    /// Keys are recorded at generations 0 and 15 for every leaf.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        const GENERATIONS: [u32; 2] = [0, 15];
        const TREE_SIZES: [u32; 3] = [1, 8, 32];

        let mut test_cases = vec![];

        let providers =
            CipherSuite::all().filter_map(|suite| try_test_cipher_suite_provider(*suite));

        for cs in providers {
            for n_leaves in TREE_SIZES {
                let encryption_secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();

                let mut tree = SecretTree::new(n_leaves, Zeroizing::new(encryption_secret.clone()));

                let mut leaves = Vec::with_capacity(n_leaves as usize);

                for leaf in 0..n_leaves {
                    let node_index = leaf * 2u32;
                    let mut entries = Vec::with_capacity(GENERATIONS.len());

                    for generation in GENERATIONS {
                        let handshake = tree
                            .message_key_generation(&cs, node_index, KeyType::Handshake, generation)
                            .unwrap();

                        let application = tree
                            .message_key_generation(
                                &cs,
                                node_index,
                                KeyType::Application,
                                generation,
                            )
                            .unwrap();

                        entries.push(InteropLeaf {
                            generation,
                            application_key: application.key.to_vec(),
                            application_nonce: application.nonce.to_vec(),
                            handshake_key: handshake.key.to_vec(),
                            handshake_nonce: handshake.nonce.to_vec(),
                        });
                    }

                    leaves.push(entries);
                }

                test_cases.push(InteropTestCase {
                    cipher_suite: *cs.cipher_suite(),
                    encryption_secret,
                    sender_data: InteropSenderData::new(&cs),
                    leaves,
                });
            }
        }

        test_cases
    }

    /// Async builds cannot regenerate the interop vector; this stub always panics.
    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        panic!("Tests cannot be generated in async mode");
    }
}