1use crate::cipher_suite::CipherSuite;
6use crate::client::MlsError;
7use crate::crypto::HpkePublicKey;
8use crate::hash_reference::HashReference;
9use crate::identity::SigningIdentity;
10use crate::protocol_version::ProtocolVersion;
11use crate::signer::Signable;
12use crate::time::MlsTime;
13use crate::tree_kem::leaf_node::{LeafNode, LeafNodeSource};
14use crate::CipherSuiteProvider;
15use alloc::vec::Vec;
16use core::{
17 fmt::{self, Debug},
18 ops::Deref,
19};
20use mls_rs_codec::MlsDecode;
21use mls_rs_codec::MlsEncode;
22use mls_rs_codec::MlsSize;
23use mls_rs_core::extension::ExtensionList;
24
25mod validator;
26pub(crate) use validator::*;
27
28pub(crate) mod generator;
29pub(crate) use generator::*;
30
31#[non_exhaustive]
32#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
33#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
34#[cfg_attr(
35 all(feature = "ffi", not(test)),
36 safer_ffi_gen::ffi_type(clone, opaque)
37)]
38#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
39pub struct KeyPackage {
40 pub version: ProtocolVersion,
41 pub cipher_suite: CipherSuite,
42 pub hpke_init_key: HpkePublicKey,
43 pub(crate) leaf_node: LeafNode,
44 pub extensions: ExtensionList,
45 #[mls_codec(with = "mls_rs_codec::byte_vec")]
46 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
47 pub signature: Vec<u8>,
48}
49
50impl Debug for KeyPackage {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.debug_struct("KeyPackage")
53 .field("version", &self.version)
54 .field("cipher_suite", &self.cipher_suite)
55 .field("hpke_init_key", &self.hpke_init_key)
56 .field("leaf_node", &self.leaf_node)
57 .field("extensions", &self.extensions)
58 .field(
59 "signature",
60 &mls_rs_core::debug::pretty_bytes(&self.signature),
61 )
62 .finish()
63 }
64}
65
66#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, MlsSize, MlsEncode, MlsDecode)]
67#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
68#[cfg_attr(
69 all(feature = "ffi", not(test)),
70 safer_ffi_gen::ffi_type(clone, opaque)
71)]
72pub struct KeyPackageRef(HashReference);
73
74impl Deref for KeyPackageRef {
75 type Target = [u8];
76
77 fn deref(&self) -> &Self::Target {
78 &self.0
79 }
80}
81
82impl From<Vec<u8>> for KeyPackageRef {
83 fn from(v: Vec<u8>) -> Self {
84 Self(HashReference::from(v))
85 }
86}
87
88#[derive(MlsSize, MlsEncode)]
89struct KeyPackageData<'a> {
90 pub version: ProtocolVersion,
91 pub cipher_suite: CipherSuite,
92 #[mls_codec(with = "mls_rs_codec::byte_vec")]
93 pub hpke_init_key: &'a HpkePublicKey,
94 pub leaf_node: &'a LeafNode,
95 pub extensions: &'a ExtensionList,
96}
97
98#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
99impl KeyPackage {
100 #[cfg(feature = "ffi")]
101 pub fn version(&self) -> ProtocolVersion {
102 self.version
103 }
104
105 #[cfg(feature = "ffi")]
106 pub fn cipher_suite(&self) -> CipherSuite {
107 self.cipher_suite
108 }
109
110 pub fn signing_identity(&self) -> &SigningIdentity {
111 &self.leaf_node.signing_identity
112 }
113
114 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
115 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
116 pub async fn to_reference<CP: CipherSuiteProvider>(
117 &self,
118 cipher_suite_provider: &CP,
119 ) -> Result<KeyPackageRef, MlsError> {
120 if cipher_suite_provider.cipher_suite() != self.cipher_suite {
121 return Err(MlsError::CipherSuiteMismatch);
122 }
123
124 Ok(KeyPackageRef(
125 HashReference::compute(
126 &self.mls_encode_to_vec()?,
127 b"MLS 1.0 KeyPackage Reference",
128 cipher_suite_provider,
129 )
130 .await?,
131 ))
132 }
133
134 pub fn expiration(&self) -> Result<MlsTime, MlsError> {
136 if let LeafNodeSource::KeyPackage(lifetime) = &self.leaf_node.leaf_node_source {
137 Ok(lifetime.not_after)
138 } else {
139 Err(MlsError::InvalidLeafNodeSource)
140 }
141 }
142}
143
144impl Signable<'_> for KeyPackage {
145 const SIGN_LABEL: &'static str = "KeyPackageTBS";
146
147 type SigningContext = ();
148
149 fn signature(&self) -> &[u8] {
150 &self.signature
151 }
152
153 fn signable_content(
154 &self,
155 _context: &Self::SigningContext,
156 ) -> Result<Vec<u8>, mls_rs_codec::Error> {
157 KeyPackageData {
158 version: self.version,
159 cipher_suite: self.cipher_suite,
160 hpke_init_key: &self.hpke_init_key,
161 leaf_node: &self.leaf_node,
162 extensions: &self.extensions,
163 }
164 .mls_encode_to_vec()
165 }
166
167 fn write_signature(&mut self, signature: Vec<u8>) {
168 self.signature = signature
169 }
170}
171
/// Helpers for building key packages in unit tests.
#[cfg(test)]
pub(crate) mod test_utils {
    use super::*;
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::framing::MlsMessagePayload,
        identity::test_utils::get_test_signing_identity,
        tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
        MlsMessage,
    };

    use mls_rs_core::crypto::SignatureSecretKey;

    /// Generates a test key package for `id`, discarding the signing key.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> KeyPackage {
        let (key_package, _signer) =
            test_key_package_with_signer(protocol_version, cipher_suite, id).await;

        key_package
    }

    /// Generates a test key package for `id` and returns it together with the
    /// secret key of the signing identity it was created with.
    ///
    /// The package uses a one-year lifetime, the test capabilities and empty
    /// extension lists. Panics if generation fails.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_with_signer(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> (KeyPackage, SignatureSecretKey) {
        let (signing_identity, secret_key) =
            get_test_signing_identity(cipher_suite, id.as_bytes()).await;

        let provider = test_cipher_suite_provider(cipher_suite);

        let generator = KeyPackageGenerator {
            protocol_version,
            cipher_suite_provider: &provider,
            signing_identity: &signing_identity,
            signing_key: &secret_key,
        };

        let lifetime = Lifetime::years(1, None).unwrap();

        let generated = generator
            .generate(
                lifetime,
                get_test_capabilities(),
                ExtensionList::default(),
                ExtensionList::default(),
            )
            .await
            .unwrap();

        (generated.key_package, secret_key)
    }

    /// Generates a test key package for `id` and wraps it in an `MlsMessage`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_message(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> MlsMessage {
        let key_package = test_key_package(protocol_version, cipher_suite, id).await;

        MlsMessage::new(protocol_version, MlsMessagePayload::KeyPackage(key_package))
    }
}
240
#[cfg(test)]
mod tests {
    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
    };

    use super::{test_utils::test_key_package, *};
    use alloc::format;
    use assert_matches::assert_matches;

    /// One key-package-reference vector: an encoded key package (`input`) and
    /// the reference bytes expected for it (`output`), both hex-encoded when
    /// serialized.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        input: Vec<u8>,
        #[serde(with = "hex::serde")]
        output: Vec<u8>,
    }

    impl TestCase {
        /// Builds one vector per (protocol version, cipher suite) pair using
        /// freshly generated key packages.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        #[cfg_attr(coverage_nightly, coverage(off))]
        async fn generate() -> Vec<TestCase> {
            let mut vectors = Vec::new();

            let combinations = ProtocolVersion::all()
                .flat_map(|version| CipherSuite::all().map(move |suite| (version, suite)));

            for (i, (protocol_version, cipher_suite)) in combinations.enumerate() {
                let key_package =
                    test_key_package(protocol_version, cipher_suite, &format!("alice{i}")).await;

                let reference = key_package
                    .to_reference(&test_cipher_suite_provider(cipher_suite))
                    .await
                    .unwrap();

                vectors.push(TestCase {
                    cipher_suite: cipher_suite.into(),
                    input: key_package.mls_encode_to_vec().unwrap(),
                    output: reference.to_vec(),
                });
            }

            vectors
        }
    }

    /// Loads the `key_package_ref` vectors (async build).
    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate().await)
    }

    /// Loads the `key_package_ref` vectors (sync build).
    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate())
    }

    /// Every stored vector must reproduce its expected reference. Vectors for
    /// cipher suites without an available provider are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_package_ref() {
        let vectors = load_test_cases().await;

        for vector in vectors {
            if let Some(provider) = try_test_cipher_suite_provider(vector.cipher_suite) {
                let decoded = KeyPackage::mls_decode(&mut vector.input.as_slice()).unwrap();

                let actual = decoded.to_reference(&provider).await.unwrap();

                let expected = KeyPackageRef::from(vector.output);
                assert_eq!(expected, actual);
            }
        }
    }

    /// Computing a reference with a provider for a different cipher suite
    /// must fail with `CipherSuiteMismatch`.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn key_package_ref_fails_invalid_cipher_suite() {
        let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;

        let other_suites = CipherSuite::all().filter(|cs| cs != &TEST_CIPHER_SUITE);

        for other_suite in other_suites {
            let Some(provider) = try_test_cipher_suite_provider(*other_suite) else {
                continue;
            };

            let res = key_package.to_reference(&provider).await;

            assert_matches!(res, Err(MlsError::CipherSuiteMismatch));
        }
    }
}