Skip to main content

mls_rs/key_package/
mod.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use crate::cipher_suite::CipherSuite;
6use crate::client::MlsError;
7use crate::crypto::HpkePublicKey;
8use crate::hash_reference::HashReference;
9use crate::identity::SigningIdentity;
10use crate::protocol_version::ProtocolVersion;
11use crate::signer::Signable;
12use crate::time::MlsTime;
13use crate::tree_kem::leaf_node::{LeafNode, LeafNodeSource};
14use crate::CipherSuiteProvider;
15use alloc::vec::Vec;
16use core::{
17    fmt::{self, Debug},
18    ops::Deref,
19};
20use mls_rs_codec::MlsDecode;
21use mls_rs_codec::MlsEncode;
22use mls_rs_codec::MlsSize;
23use mls_rs_core::extension::ExtensionList;
24
25mod validator;
26pub(crate) use validator::*;
27
28pub(crate) mod generator;
29pub(crate) use generator::*;
30
31#[non_exhaustive]
32#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
33#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
34#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
35pub struct KeyPackage {
36    pub version: ProtocolVersion,
37    pub cipher_suite: CipherSuite,
38    pub hpke_init_key: HpkePublicKey,
39    pub(crate) leaf_node: LeafNode,
40    pub extensions: ExtensionList,
41    #[mls_codec(with = "mls_rs_codec::byte_vec")]
42    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
43    pub signature: Vec<u8>,
44}
45
46impl Debug for KeyPackage {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        f.debug_struct("KeyPackage")
49            .field("version", &self.version)
50            .field("cipher_suite", &self.cipher_suite)
51            .field("hpke_init_key", &self.hpke_init_key)
52            .field("leaf_node", &self.leaf_node)
53            .field("extensions", &self.extensions)
54            .field(
55                "signature",
56                &mls_rs_core::debug::pretty_bytes(&self.signature),
57            )
58            .finish()
59    }
60}
61
62#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, MlsSize, MlsEncode, MlsDecode)]
63#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
64pub struct KeyPackageRef(HashReference);
65
66impl Deref for KeyPackageRef {
67    type Target = [u8];
68
69    fn deref(&self) -> &Self::Target {
70        &self.0
71    }
72}
73
74impl From<Vec<u8>> for KeyPackageRef {
75    fn from(v: Vec<u8>) -> Self {
76        Self(HashReference::from(v))
77    }
78}
79
80#[derive(MlsSize, MlsEncode)]
81struct KeyPackageData<'a> {
82    pub version: ProtocolVersion,
83    pub cipher_suite: CipherSuite,
84    #[mls_codec(with = "mls_rs_codec::byte_vec")]
85    pub hpke_init_key: &'a HpkePublicKey,
86    pub leaf_node: &'a LeafNode,
87    pub extensions: &'a ExtensionList,
88}
89
90impl KeyPackage {
91    pub fn version(&self) -> ProtocolVersion {
92        self.version
93    }
94
95    pub fn cipher_suite(&self) -> CipherSuite {
96        self.cipher_suite
97    }
98
99    pub fn signing_identity(&self) -> &SigningIdentity {
100        &self.leaf_node.signing_identity
101    }
102
103    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
104    pub async fn to_reference<CP: CipherSuiteProvider>(
105        &self,
106        cipher_suite_provider: &CP,
107    ) -> Result<KeyPackageRef, MlsError> {
108        if cipher_suite_provider.cipher_suite() != self.cipher_suite {
109            return Err(MlsError::CipherSuiteMismatch);
110        }
111
112        Ok(KeyPackageRef(
113            HashReference::compute(
114                &self.mls_encode_to_vec()?,
115                b"MLS 1.0 KeyPackage Reference",
116                cipher_suite_provider,
117            )
118            .await?,
119        ))
120    }
121
122    /// Time after which the key package is expired.
123    pub fn expiration(&self) -> Result<MlsTime, MlsError> {
124        if let LeafNodeSource::KeyPackage(lifetime) = &self.leaf_node.leaf_node_source {
125            Ok(lifetime.not_after)
126        } else {
127            Err(MlsError::InvalidLeafNodeSource)
128        }
129    }
130}
131
132impl Signable<'_> for KeyPackage {
133    const SIGN_LABEL: &'static str = "KeyPackageTBS";
134
135    type SigningContext = ();
136
137    fn signature(&self) -> &[u8] {
138        &self.signature
139    }
140
141    fn signable_content(
142        &self,
143        _context: &Self::SigningContext,
144    ) -> Result<Vec<u8>, mls_rs_codec::Error> {
145        KeyPackageData {
146            version: self.version,
147            cipher_suite: self.cipher_suite,
148            hpke_init_key: &self.hpke_init_key,
149            leaf_node: &self.leaf_node,
150            extensions: &self.extensions,
151        }
152        .mls_encode_to_vec()
153    }
154
155    fn write_signature(&mut self, signature: Vec<u8>) {
156        self.signature = signature
157    }
158}
159
#[cfg(test)]
pub(crate) mod test_utils {
    //! Helpers that build throw-away key packages for unit tests.

    use super::*;
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::framing::MlsMessagePayload,
        identity::test_utils::get_test_signing_identity,
        tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
        MlsMessage,
    };

    use mls_rs_core::crypto::SignatureSecretKey;

    /// Generates a test key package for `id`, discarding the signing key.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> KeyPackage {
        let (key_package, _signer) =
            test_key_package_with_signer(protocol_version, cipher_suite, id).await;

        key_package
    }

    /// Generates a test key package for `id` together with the signature
    /// secret key of the freshly created test identity.
    ///
    /// The package is generated with `Lifetime::years(1, None)`, the test
    /// capabilities and two default (empty) extension lists. Panics if
    /// generation fails.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_with_signer(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> (KeyPackage, SignatureSecretKey) {
        let (signing_identity, secret_key) =
            get_test_signing_identity(cipher_suite, id.as_bytes()).await;

        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

        let generator = KeyPackageGenerator {
            protocol_version,
            cipher_suite_provider: &cipher_suite_provider,
            signing_identity: &signing_identity,
            signing_key: &secret_key,
        };

        let generated = generator
            .generate(
                Lifetime::years(1, None).unwrap(),
                get_test_capabilities(),
                ExtensionList::default(),
                ExtensionList::default(),
            )
            .await
            .unwrap();

        (generated.key_package, secret_key)
    }

    /// Generates a test key package for `id` and wraps it in an `MlsMessage`
    /// with a `KeyPackage` payload.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_message(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> MlsMessage {
        let key_package = test_key_package(protocol_version, cipher_suite, id).await;

        MlsMessage::new(protocol_version, MlsMessagePayload::KeyPackage(key_package))
    }
}
228
#[cfg(test)]
mod tests {
    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
    };

    use super::{test_utils::test_key_package, *};
    use alloc::format;
    use assert_matches::assert_matches;

    /// One key-package-reference test vector: an MLS-encoded key package
    /// (`input`) and the reference bytes expected for it (`output`), both
    /// hex-encoded when serialized.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        input: Vec<u8>,
        #[serde(with = "hex::serde")]
        output: Vec<u8>,
    }

    impl TestCase {
        /// Builds one vector per (protocol version, cipher suite) pair, each
        /// with its own `alice{i}` identity.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        #[cfg_attr(coverage_nightly, coverage(off))]
        async fn generate() -> Vec<TestCase> {
            let mut test_cases = Vec::new();

            let combinations = ProtocolVersion::all()
                .flat_map(|version| CipherSuite::all().map(move |suite| (version, suite)));

            for (i, (protocol_version, cipher_suite)) in combinations.enumerate() {
                let id = format!("alice{i}");
                let pkg = test_key_package(protocol_version, cipher_suite, &id).await;

                let provider = test_cipher_suite_provider(cipher_suite);
                let pkg_ref = pkg.to_reference(&provider).await.unwrap();

                test_cases.push(TestCase {
                    cipher_suite: cipher_suite.into(),
                    input: pkg.mls_encode_to_vec().unwrap(),
                    output: pkg_ref.to_vec(),
                });
            }

            test_cases
        }
    }

    /// Loads the `key_package_ref` vectors through `load_test_case_json!`,
    /// passing `TestCase::generate` as the generator expression (async build).
    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate().await)
    }

    /// Loads the `key_package_ref` vectors through `load_test_case_json!`,
    /// passing `TestCase::generate` as the generator expression (sync build).
    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate())
    }

    /// Every stored vector must decode and hash to its recorded reference.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_package_ref() {
        let vectors = load_test_cases().await;

        for TestCase {
            cipher_suite,
            input,
            output,
        } in vectors
        {
            // Vectors whose suite has no available test provider are skipped.
            let Some(provider) = try_test_cipher_suite_provider(cipher_suite) else {
                continue;
            };

            let key_package = KeyPackage::mls_decode(&mut input.as_slice()).unwrap();
            let actual = key_package.to_reference(&provider).await.unwrap();

            let expected = KeyPackageRef::from(output);
            assert_eq!(expected, actual);
        }
    }

    /// `to_reference` must reject a provider whose suite differs from the
    /// key package's suite.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn key_package_ref_fails_invalid_cipher_suite() {
        let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;

        let other_suites = CipherSuite::all().filter(|suite| suite != &TEST_CIPHER_SUITE);

        for other_suite in other_suites {
            let Some(provider) = try_test_cipher_suite_provider(*other_suite) else {
                continue;
            };

            let res = key_package.to_reference(&provider).await;

            assert_matches!(res, Err(MlsError::CipherSuiteMismatch));
        }
    }
}