mls_rs/key_package/
mod.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use crate::cipher_suite::CipherSuite;
6use crate::client::MlsError;
7use crate::crypto::HpkePublicKey;
8use crate::hash_reference::HashReference;
9use crate::identity::SigningIdentity;
10use crate::protocol_version::ProtocolVersion;
11use crate::signer::Signable;
12use crate::time::MlsTime;
13use crate::tree_kem::leaf_node::{LeafNode, LeafNodeSource};
14use crate::CipherSuiteProvider;
15use alloc::vec::Vec;
16use core::{
17    fmt::{self, Debug},
18    ops::Deref,
19};
20use mls_rs_codec::MlsDecode;
21use mls_rs_codec::MlsEncode;
22use mls_rs_codec::MlsSize;
23use mls_rs_core::extension::ExtensionList;
24
25mod validator;
26pub(crate) use validator::*;
27
28pub(crate) mod generator;
29pub(crate) use generator::*;
30
31#[non_exhaustive]
32#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
33#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
34#[cfg_attr(
35    all(feature = "ffi", not(test)),
36    safer_ffi_gen::ffi_type(clone, opaque)
37)]
38#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
39pub struct KeyPackage {
40    pub version: ProtocolVersion,
41    pub cipher_suite: CipherSuite,
42    pub hpke_init_key: HpkePublicKey,
43    pub(crate) leaf_node: LeafNode,
44    pub extensions: ExtensionList,
45    #[mls_codec(with = "mls_rs_codec::byte_vec")]
46    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
47    pub signature: Vec<u8>,
48}
49
50impl Debug for KeyPackage {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        f.debug_struct("KeyPackage")
53            .field("version", &self.version)
54            .field("cipher_suite", &self.cipher_suite)
55            .field("hpke_init_key", &self.hpke_init_key)
56            .field("leaf_node", &self.leaf_node)
57            .field("extensions", &self.extensions)
58            .field(
59                "signature",
60                &mls_rs_core::debug::pretty_bytes(&self.signature),
61            )
62            .finish()
63    }
64}
65
66#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, MlsSize, MlsEncode, MlsDecode)]
67#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
68#[cfg_attr(
69    all(feature = "ffi", not(test)),
70    safer_ffi_gen::ffi_type(clone, opaque)
71)]
72pub struct KeyPackageRef(HashReference);
73
74impl Deref for KeyPackageRef {
75    type Target = [u8];
76
77    fn deref(&self) -> &Self::Target {
78        &self.0
79    }
80}
81
82impl From<Vec<u8>> for KeyPackageRef {
83    fn from(v: Vec<u8>) -> Self {
84        Self(HashReference::from(v))
85    }
86}
87
88#[derive(MlsSize, MlsEncode)]
89struct KeyPackageData<'a> {
90    pub version: ProtocolVersion,
91    pub cipher_suite: CipherSuite,
92    #[mls_codec(with = "mls_rs_codec::byte_vec")]
93    pub hpke_init_key: &'a HpkePublicKey,
94    pub leaf_node: &'a LeafNode,
95    pub extensions: &'a ExtensionList,
96}
97
98#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
99impl KeyPackage {
100    #[cfg(feature = "ffi")]
101    pub fn version(&self) -> ProtocolVersion {
102        self.version
103    }
104
105    #[cfg(feature = "ffi")]
106    pub fn cipher_suite(&self) -> CipherSuite {
107        self.cipher_suite
108    }
109
110    pub fn signing_identity(&self) -> &SigningIdentity {
111        &self.leaf_node.signing_identity
112    }
113
114    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
115    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
116    pub async fn to_reference<CP: CipherSuiteProvider>(
117        &self,
118        cipher_suite_provider: &CP,
119    ) -> Result<KeyPackageRef, MlsError> {
120        if cipher_suite_provider.cipher_suite() != self.cipher_suite {
121            return Err(MlsError::CipherSuiteMismatch);
122        }
123
124        Ok(KeyPackageRef(
125            HashReference::compute(
126                &self.mls_encode_to_vec()?,
127                b"MLS 1.0 KeyPackage Reference",
128                cipher_suite_provider,
129            )
130            .await?,
131        ))
132    }
133
134    /// Time after which the key package is expired.
135    pub fn expiration(&self) -> Result<MlsTime, MlsError> {
136        if let LeafNodeSource::KeyPackage(lifetime) = &self.leaf_node.leaf_node_source {
137            Ok(lifetime.not_after)
138        } else {
139            Err(MlsError::InvalidLeafNodeSource)
140        }
141    }
142}
143
144impl Signable<'_> for KeyPackage {
145    const SIGN_LABEL: &'static str = "KeyPackageTBS";
146
147    type SigningContext = ();
148
149    fn signature(&self) -> &[u8] {
150        &self.signature
151    }
152
153    fn signable_content(
154        &self,
155        _context: &Self::SigningContext,
156    ) -> Result<Vec<u8>, mls_rs_codec::Error> {
157        KeyPackageData {
158            version: self.version,
159            cipher_suite: self.cipher_suite,
160            hpke_init_key: &self.hpke_init_key,
161            leaf_node: &self.leaf_node,
162            extensions: &self.extensions,
163        }
164        .mls_encode_to_vec()
165    }
166
167    fn write_signature(&mut self, signature: Vec<u8>) {
168        self.signature = signature
169    }
170}
171
#[cfg(test)]
pub(crate) mod test_utils {
    use super::*;
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::framing::MlsMessagePayload,
        identity::test_utils::get_test_signing_identity,
        tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
        MlsMessage,
    };

    use mls_rs_core::crypto::SignatureSecretKey;

    /// Generates a test key package for `id`, discarding the signing key.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> KeyPackage {
        let (key_package, _) =
            test_key_package_with_signer(protocol_version, cipher_suite, id).await;

        key_package
    }

    /// Generates a test key package for `id` and returns it together with the
    /// secret key that signed it.
    ///
    /// The package uses a one-year lifetime, the test capabilities and two
    /// empty extension lists.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_with_signer(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> (KeyPackage, SignatureSecretKey) {
        let (signing_identity, secret_key) =
            get_test_signing_identity(cipher_suite, id.as_bytes()).await;

        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
        let lifetime = Lifetime::years(1, None).unwrap();

        let generator = KeyPackageGenerator {
            protocol_version,
            cipher_suite_provider: &cipher_suite_provider,
            signing_identity: &signing_identity,
            signing_key: &secret_key,
        };

        let generated = generator
            .generate(
                lifetime,
                get_test_capabilities(),
                ExtensionList::default(),
                ExtensionList::default(),
            )
            .await
            .unwrap();

        (generated.key_package, secret_key)
    }

    /// Wraps a freshly generated test key package for `id` in an `MlsMessage`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_message(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> MlsMessage {
        let key_package = test_key_package(protocol_version, cipher_suite, id).await;
        let payload = MlsMessagePayload::KeyPackage(key_package);

        MlsMessage::new(protocol_version, payload)
    }
}
240
#[cfg(test)]
mod tests {
    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
    };

    use super::{test_utils::test_key_package, *};
    use alloc::format;
    use assert_matches::assert_matches;

    /// One known-answer vector: an encoded key package (`input`) and the
    /// reference bytes (`output`) expected under `cipher_suite`.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        input: Vec<u8>,
        #[serde(with = "hex::serde")]
        output: Vec<u8>,
    }

    impl TestCase {
        /// Builds one vector for every (protocol version, cipher suite) pair.
        /// The pair's position in the enumeration is used to name the identity
        /// (`alice0`, `alice1`, ...).
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        #[cfg_attr(coverage_nightly, coverage(off))]
        async fn generate() -> Vec<TestCase> {
            let combinations = ProtocolVersion::all()
                .flat_map(|version| CipherSuite::all().map(move |suite| (version, suite)));

            let mut vectors = Vec::new();

            for (index, (protocol_version, cipher_suite)) in combinations.enumerate() {
                let identity = format!("alice{index}");
                let key_package = test_key_package(protocol_version, cipher_suite, &identity).await;

                let reference = key_package
                    .to_reference(&test_cipher_suite_provider(cipher_suite))
                    .await
                    .unwrap();

                vectors.push(TestCase {
                    cipher_suite: cipher_suite.into(),
                    input: key_package.mls_encode_to_vec().unwrap(),
                    output: reference.to_vec(),
                });
            }

            vectors
        }
    }

    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate().await)
    }

    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate())
    }

    /// Decodes each stored key package and checks that its recomputed
    /// reference matches the stored one. Suites without a test provider are
    /// skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_package_ref() {
        for case in load_test_cases().await {
            let Some(provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
                continue;
            };

            let decoded = KeyPackage::mls_decode(&mut case.input.as_slice()).unwrap();
            let computed = decoded.to_reference(&provider).await.unwrap();

            assert_eq!(KeyPackageRef::from(case.output), computed);
        }
    }

    /// A provider for any suite other than the package's own must be rejected
    /// with `CipherSuiteMismatch`.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn key_package_ref_fails_invalid_cipher_suite() {
        let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;

        let other_suites = CipherSuite::all().filter(|suite| *suite != TEST_CIPHER_SUITE);

        for other_suite in other_suites {
            let Some(provider) = try_test_cipher_suite_provider(*other_suite) else {
                continue;
            };

            let result = key_package.to_reference(&provider).await;
            assert_matches!(result, Err(MlsError::CipherSuiteMismatch));
        }
    }
}