mls_rs/
psk.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use alloc::vec::Vec;
6
7#[cfg(any(test, feature = "external_client"))]
8use alloc::vec;
9
10use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
11
12#[cfg(any(test, feature = "external_client"))]
13use mls_rs_core::psk::PreSharedKeyStorage;
14
15#[cfg(any(test, feature = "external_client"))]
16use core::convert::Infallible;
17use core::fmt::{self, Debug};
18
19#[cfg(feature = "psk")]
20use crate::{client::MlsError, CipherSuiteProvider};
21
22#[cfg(feature = "psk")]
23use mls_rs_core::error::IntoAnyError;
24
25#[cfg(feature = "psk")]
26pub(crate) mod resolver;
27pub(crate) mod secret;
28
29pub use mls_rs_core::psk::{ExternalPskId, PreSharedKey};
30
31#[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
32#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
33#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
34pub(crate) struct PreSharedKeyID {
35    pub key_id: JustPreSharedKeyID,
36    pub psk_nonce: PskNonce,
37}
38
39impl PreSharedKeyID {
40    #[cfg(feature = "psk")]
41    pub(crate) fn new<P: CipherSuiteProvider>(
42        key_id: JustPreSharedKeyID,
43        cs: &P,
44    ) -> Result<Self, MlsError> {
45        Ok(Self {
46            key_id,
47            psk_nonce: PskNonce::random(cs)
48                .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?,
49        })
50    }
51}
52
53#[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
54#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
55#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
56#[repr(u8)]
57pub(crate) enum JustPreSharedKeyID {
58    External(ExternalPskId) = 1u8,
59    Resumption(ResumptionPsk) = 2u8,
60}
61
62#[derive(Clone, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
63#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
64#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
65pub(crate) struct PskGroupId(
66    #[mls_codec(with = "mls_rs_codec::byte_vec")]
67    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
68    pub Vec<u8>,
69);
70
71impl Debug for PskGroupId {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        mls_rs_core::debug::pretty_bytes(&self.0)
74            .named("PskGroupId")
75            .fmt(f)
76    }
77}
78
79#[derive(Clone, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
80#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
81#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
82pub(crate) struct PskNonce(
83    #[mls_codec(with = "mls_rs_codec::byte_vec")]
84    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
85    pub Vec<u8>,
86);
87
88impl Debug for PskNonce {
89    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90        mls_rs_core::debug::pretty_bytes(&self.0)
91            .named("PskNonce")
92            .fmt(f)
93    }
94}
95
#[cfg(feature = "psk")]
impl PskNonce {
    /// Generates a nonce of random bytes whose length equals the provider's
    /// `kdf_extract_size()`.
    ///
    /// # Errors
    ///
    /// Returns the provider's own error if random byte generation fails.
    pub fn random<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
    ) -> Result<Self, <P as CipherSuiteProvider>::Error> {
        let nonce_len = cipher_suite_provider.kdf_extract_size();
        let nonce_bytes = cipher_suite_provider.random_bytes_vec(nonce_len)?;

        Ok(Self(nonce_bytes))
    }
}
106
107#[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
108#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
109#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
110pub(crate) struct ResumptionPsk {
111    pub usage: ResumptionPSKUsage,
112    pub psk_group_id: PskGroupId,
113    pub psk_epoch: u64,
114}
115
116#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd, MlsSize, MlsEncode, MlsDecode)]
117#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
118#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
119#[repr(u8)]
120pub(crate) enum ResumptionPSKUsage {
121    Application = 1u8,
122    Reinit = 2u8,
123    Branch = 3u8,
124}
125
/// Encode-only label that borrows a [`PreSharedKeyID`] and adds an `index`
/// and a `count`.
///
/// The field layout matches the `PSKLabel` structure of RFC 9420. It is not
/// used in this file; presumably it is consumed by the PSK submodules when
/// deriving secrets (confirm at the use sites).
#[cfg(feature = "psk")]
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
struct PSKLabel<'a> {
    /// The key identifier this label refers to.
    id: &'a PreSharedKeyID,
    /// Index value stored in the label.
    index: u16,
    /// Count value stored in the label.
    count: u16,
}
133
/// Stateless pre-shared key storage whose lookups never miss: its
/// `PreSharedKeyStorage::get` answers every id with an empty key.
#[cfg(any(test, feature = "external_client"))]
#[derive(Debug, Clone, Copy)]
pub(crate) struct AlwaysFoundPskStorage;
137
#[cfg(any(test, feature = "external_client"))]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
impl PreSharedKeyStorage for AlwaysFoundPskStorage {
    /// Lookups cannot fail.
    type Error = Infallible;

    /// Ignores the requested id and always returns `Ok(Some(..))` holding a
    /// key built from an empty byte vector.
    async fn get(&self, _: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
        let empty_key: PreSharedKey = vec![].into();
        Ok(Some(empty_key))
    }
}
148
/// Helpers shared by tests that need PSK ids and nonces.
#[cfg(all(test, feature = "psk"))]
pub(crate) mod test_utils {
    use super::PskNonce;
    use crate::crypto::test_utils::test_cipher_suite_provider;
    use mls_rs_core::crypto::CipherSuite;

    #[cfg(not(mls_build_async))]
    use mls_rs_core::{crypto::CipherSuiteProvider, psk::ExternalPskId};

    /// Builds an [`ExternalPskId`] from `kdf_extract_size()` random bytes
    /// drawn from the given provider.
    ///
    /// # Panics
    ///
    /// Panics if the provider fails to produce random bytes.
    #[cfg_attr(coverage_nightly, coverage(off))]
    #[cfg(not(mls_build_async))]
    pub(crate) fn make_external_psk_id<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
    ) -> ExternalPskId {
        let id_len = cipher_suite_provider.kdf_extract_size();
        let id_bytes = cipher_suite_provider.random_bytes_vec(id_len).unwrap();

        ExternalPskId::new(id_bytes)
    }

    /// Draws a random [`PskNonce`] using the test provider for `cipher_suite`.
    ///
    /// # Panics
    ///
    /// Panics if nonce generation fails.
    pub(crate) fn make_nonce(cipher_suite: CipherSuite) -> PskNonce {
        let provider = test_cipher_suite_provider(cipher_suite);
        PskNonce::random(&provider).unwrap()
    }
}
176
#[cfg(all(test, feature = "psk"))]
mod tests {
    use super::test_utils::make_nonce;
    use crate::crypto::test_utils::TestCryptoProvider;
    use core::iter;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// For every supported cipher suite, a reference nonce must differ from
    /// each of 1000 further freshly generated nonces.
    #[test]
    fn random_generation_of_nonces_is_random() {
        // True when any of the 1000 fresh nonces equals the reference nonce.
        let repeats_reference = |cipher_suite| {
            let reference = make_nonce(cipher_suite);

            iter::repeat_with(|| make_nonce(cipher_suite))
                .take(1000)
                .any(|candidate| candidate == reference)
        };

        let good = !TestCryptoProvider::all_supported_cipher_suites()
            .into_iter()
            .any(repeats_reference);

        assert!(good);
    }
}