1use crate::cipher_suite::CipherSuite;
6use crate::client::MlsError;
7use crate::crypto::HpkePublicKey;
8use crate::hash_reference::HashReference;
9use crate::identity::SigningIdentity;
10use crate::protocol_version::ProtocolVersion;
11use crate::signer::Signable;
12use crate::time::MlsTime;
13use crate::tree_kem::leaf_node::{LeafNode, LeafNodeSource};
14use crate::CipherSuiteProvider;
15use alloc::vec::Vec;
16use core::{
17 fmt::{self, Debug},
18 ops::Deref,
19};
20use mls_rs_codec::MlsDecode;
21use mls_rs_codec::MlsEncode;
22use mls_rs_codec::MlsSize;
23use mls_rs_core::extension::ExtensionList;
24
25mod validator;
26pub(crate) use validator::*;
27
28pub(crate) mod generator;
29pub(crate) use generator::*;
30
/// An MLS key package, with fields mirroring the `KeyPackage` structure of
/// RFC 9420 §10.
///
/// The derived `MlsEncode`/`MlsDecode` impls define the encoded form of this
/// type. Do not reorder or re-annotate fields without considering wire
/// compatibility.
///
/// `signature` is excluded from the signed content. The `Signable` impl
/// below encodes every other field via `KeyPackageData`.
#[non_exhaustive]
#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeyPackage {
    /// Protocol version this key package was generated for.
    pub version: ProtocolVersion,
    /// Cipher suite this key package was generated for. `to_reference`
    /// rejects providers for any other suite.
    pub cipher_suite: CipherSuite,
    /// HPKE public key carried by the key package.
    pub hpke_init_key: HpkePublicKey,
    /// Leaf node for the member. Crate-private; the signing identity is
    /// exposed through `signing_identity()`.
    pub(crate) leaf_node: LeafNode,
    /// Key-package-level extensions.
    pub extensions: ExtensionList,
    /// Signature bytes, encoded as a variable-length byte vector on the wire.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    pub signature: Vec<u8>,
}
45
46impl Debug for KeyPackage {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 f.debug_struct("KeyPackage")
49 .field("version", &self.version)
50 .field("cipher_suite", &self.cipher_suite)
51 .field("hpke_init_key", &self.hpke_init_key)
52 .field("leaf_node", &self.leaf_node)
53 .field("extensions", &self.extensions)
54 .field(
55 "signature",
56 &mls_rs_core::debug::pretty_bytes(&self.signature),
57 )
58 .finish()
59 }
60}
61
/// Hash-based reference to a [`KeyPackage`], as produced by
/// `KeyPackage::to_reference`.
///
/// A thin newtype over `HashReference`. It dereferences to the raw bytes and
/// can be built from a `Vec<u8>`; the `From` impl performs no length or
/// content validation.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct KeyPackageRef(HashReference);
65
66impl Deref for KeyPackageRef {
67 type Target = [u8];
68
69 fn deref(&self) -> &Self::Target {
70 &self.0
71 }
72}
73
74impl From<Vec<u8>> for KeyPackageRef {
75 fn from(v: Vec<u8>) -> Self {
76 Self(HashReference::from(v))
77 }
78}
79
/// Borrowed, encode-only view of a [`KeyPackage`] without its signature.
///
/// Its encoding is the content passed to the signer/verifier (see
/// `Signable::signable_content` for `KeyPackage`). Field order defines the
/// encoded bytes and must match the corresponding prefix of `KeyPackage`.
#[derive(MlsSize, MlsEncode)]
struct KeyPackageData<'a> {
    pub version: ProtocolVersion,
    pub cipher_suite: CipherSuite,
    // Encoded explicitly as a variable-length byte vector.
    // NOTE(review): presumably this matches how `HpkePublicKey` encodes itself
    // inside `KeyPackage`, so the signed bytes equal the package encoding minus
    // the signature — confirm against `HpkePublicKey`'s codec.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub hpke_init_key: &'a HpkePublicKey,
    pub leaf_node: &'a LeafNode,
    pub extensions: &'a ExtensionList,
}
89
90impl KeyPackage {
91 pub fn version(&self) -> ProtocolVersion {
92 self.version
93 }
94
95 pub fn cipher_suite(&self) -> CipherSuite {
96 self.cipher_suite
97 }
98
99 pub fn signing_identity(&self) -> &SigningIdentity {
100 &self.leaf_node.signing_identity
101 }
102
103 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
104 pub async fn to_reference<CP: CipherSuiteProvider>(
105 &self,
106 cipher_suite_provider: &CP,
107 ) -> Result<KeyPackageRef, MlsError> {
108 if cipher_suite_provider.cipher_suite() != self.cipher_suite {
109 return Err(MlsError::CipherSuiteMismatch);
110 }
111
112 Ok(KeyPackageRef(
113 HashReference::compute(
114 &self.mls_encode_to_vec()?,
115 b"MLS 1.0 KeyPackage Reference",
116 cipher_suite_provider,
117 )
118 .await?,
119 ))
120 }
121
122 pub fn expiration(&self) -> Result<MlsTime, MlsError> {
124 if let LeafNodeSource::KeyPackage(lifetime) = &self.leaf_node.leaf_node_source {
125 Ok(lifetime.not_after)
126 } else {
127 Err(MlsError::InvalidLeafNodeSource)
128 }
129 }
130}
131
132impl Signable<'_> for KeyPackage {
133 const SIGN_LABEL: &'static str = "KeyPackageTBS";
134
135 type SigningContext = ();
136
137 fn signature(&self) -> &[u8] {
138 &self.signature
139 }
140
141 fn signable_content(
142 &self,
143 _context: &Self::SigningContext,
144 ) -> Result<Vec<u8>, mls_rs_codec::Error> {
145 KeyPackageData {
146 version: self.version,
147 cipher_suite: self.cipher_suite,
148 hpke_init_key: &self.hpke_init_key,
149 leaf_node: &self.leaf_node,
150 extensions: &self.extensions,
151 }
152 .mls_encode_to_vec()
153 }
154
155 fn write_signature(&mut self, signature: Vec<u8>) {
156 self.signature = signature
157 }
158}
159
#[cfg(test)]
pub(crate) mod test_utils {
    use super::*;
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::framing::MlsMessagePayload,
        identity::test_utils::get_test_signing_identity,
        tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
        MlsMessage,
    };

    use mls_rs_core::crypto::SignatureSecretKey;

    /// Builds a test key package for `id`, dropping the signing key.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> KeyPackage {
        let (key_package, _) =
            test_key_package_with_signer(protocol_version, cipher_suite, id).await;

        key_package
    }

    /// Builds a test key package for `id` and returns it together with the
    /// secret key that signed it.
    ///
    /// Uses a one-year lifetime, test capabilities and empty extension lists.
    /// Panics if generation fails.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_with_signer(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> (KeyPackage, SignatureSecretKey) {
        let (signing_identity, secret_key) =
            get_test_signing_identity(cipher_suite, id.as_bytes()).await;

        let provider = test_cipher_suite_provider(cipher_suite);

        let generator = KeyPackageGenerator {
            protocol_version,
            cipher_suite_provider: &provider,
            signing_identity: &signing_identity,
            signing_key: &secret_key,
        };

        let key_package = generator
            .generate(
                Lifetime::years(1, None).unwrap(),
                get_test_capabilities(),
                ExtensionList::default(),
                ExtensionList::default(),
            )
            .await
            .unwrap()
            .key_package;

        (key_package, secret_key)
    }

    /// Wraps a freshly built test key package for `id` in an `MlsMessage`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_message(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> MlsMessage {
        let key_package = test_key_package(protocol_version, cipher_suite, id).await;
        let payload = MlsMessagePayload::KeyPackage(key_package);

        MlsMessage::new(protocol_version, payload)
    }
}
228
#[cfg(test)]
mod tests {
    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
    };

    use super::{test_utils::test_key_package, *};
    use alloc::format;
    use assert_matches::assert_matches;

    /// Known-answer vector: an encoded key package (`input`) and its expected
    /// reference bytes (`output`), both hex-encoded in the JSON form.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        input: Vec<u8>,
        #[serde(with = "hex::serde")]
        output: Vec<u8>,
    }

    impl TestCase {
        /// Produces one vector per (protocol version, cipher suite) pair. The
        /// identity of vector `i` is `alice{i}`.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        #[cfg_attr(coverage_nightly, coverage(off))]
        async fn generate() -> Vec<TestCase> {
            let combinations = ProtocolVersion::all()
                .flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)));

            let mut vectors = Vec::new();

            for (index, (protocol_version, cipher_suite)) in combinations.enumerate() {
                let identity = format!("alice{index}");
                let key_package = test_key_package(protocol_version, cipher_suite, &identity).await;

                let provider = test_cipher_suite_provider(cipher_suite);
                let reference = key_package.to_reference(&provider).await.unwrap();

                vectors.push(TestCase {
                    cipher_suite: cipher_suite.into(),
                    input: key_package.mls_encode_to_vec().unwrap(),
                    output: reference.to_vec(),
                });
            }

            vectors
        }
    }

    // NOTE(review): `load_test_case_json!` presumably reads the stored
    // `key_package_ref` vectors and falls back to the given generator —
    // confirm against the macro definition.
    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate().await)
    }

    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate())
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_package_ref() {
        let vectors = load_test_cases().await;

        for vector in vectors {
            // Skip suites the test crypto provider does not support.
            let Some(provider) = try_test_cipher_suite_provider(vector.cipher_suite) else {
                continue;
            };

            let decoded = KeyPackage::mls_decode(&mut vector.input.as_slice()).unwrap();
            let computed = decoded.to_reference(&provider).await.unwrap();

            assert_eq!(KeyPackageRef::from(vector.output), computed);
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn key_package_ref_fails_invalid_cipher_suite() {
        let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;

        let other_suites = CipherSuite::all().filter(|cs| cs != &TEST_CIPHER_SUITE);

        for other_suite in other_suites {
            let Some(provider) = try_test_cipher_suite_provider(*other_suite) else {
                continue;
            };

            let res = key_package.to_reference(&provider).await;

            assert_matches!(res, Err(MlsError::CipherSuiteMismatch));
        }
    }
}