mls_rs/
signer.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use alloc::vec::Vec;
6use core::fmt::{self, Debug};
7use mls_rs_codec::{MlsEncode, MlsSize};
8use mls_rs_core::error::IntoAnyError;
9
10use crate::client::MlsError;
11use crate::crypto::{CipherSuiteProvider, SignaturePublicKey, SignatureSecretKey};
12
13#[derive(Clone, MlsSize, MlsEncode)]
14struct SignContent {
15    #[mls_codec(with = "mls_rs_codec::byte_vec")]
16    label: Vec<u8>,
17    #[mls_codec(with = "mls_rs_codec::byte_vec")]
18    content: Vec<u8>,
19}
20
21impl Debug for SignContent {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        f.debug_struct("SignContent")
24            .field("label", &mls_rs_core::debug::pretty_bytes(&self.label))
25            .field("content", &mls_rs_core::debug::pretty_bytes(&self.content))
26            .finish()
27    }
28}
29
30impl SignContent {
31    pub fn new(label: &str, content: Vec<u8>) -> Self {
32        Self {
33            label: [b"MLS 1.0 ", label.as_bytes()].concat(),
34            content,
35        }
36    }
37}
38
39#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
40#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
41#[cfg_attr(
42    all(not(target_arch = "wasm32"), mls_build_async),
43    maybe_async::must_be_async
44)]
45pub(crate) trait Signable<'a> {
46    const SIGN_LABEL: &'static str;
47
48    type SigningContext: Send + Sync;
49
50    fn signature(&self) -> &[u8];
51
52    fn signable_content(
53        &self,
54        context: &Self::SigningContext,
55    ) -> Result<Vec<u8>, mls_rs_codec::Error>;
56
57    fn write_signature(&mut self, signature: Vec<u8>);
58
59    async fn sign<P: CipherSuiteProvider>(
60        &mut self,
61        signature_provider: &P,
62        signer: &SignatureSecretKey,
63        context: &Self::SigningContext,
64    ) -> Result<(), MlsError> {
65        let sign_content = SignContent::new(Self::SIGN_LABEL, self.signable_content(context)?);
66
67        let signature = signature_provider
68            .sign(signer, &sign_content.mls_encode_to_vec()?)
69            .await
70            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
71
72        self.write_signature(signature);
73
74        Ok(())
75    }
76
77    async fn verify<P: CipherSuiteProvider>(
78        &self,
79        signature_provider: &P,
80        public_key: &SignaturePublicKey,
81        context: &Self::SigningContext,
82    ) -> Result<(), MlsError> {
83        let sign_content = SignContent::new(Self::SIGN_LABEL, self.signable_content(context)?);
84
85        signature_provider
86            .verify(
87                public_key,
88                self.signature(),
89                &sign_content.mls_encode_to_vec()?,
90            )
91            .await
92            .map_err(|_| MlsError::InvalidSignature)
93    }
94}
95
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec;
    use alloc::{string::String, vec::Vec};
    use mls_rs_core::crypto::CipherSuiteProvider;

    use crate::crypto::test_utils::try_test_cipher_suite_provider;

    use super::Signable;

    /// The `sign_with_label` section of one `basic_crypto` interop vector.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct SignatureInteropTestCase {
        // Hex-encoded signing key, stored under the JSON key `priv`. Not used by `verify`.
        #[serde(with = "hex::serde", rename = "priv")]
        secret: Vec<u8>,
        // Hex-encoded verification key, stored under the JSON key `pub`.
        #[serde(with = "hex::serde", rename = "pub")]
        public: Vec<u8>,
        #[serde(with = "hex::serde")]
        content: Vec<u8>,
        // Label the vector was signed with. `verify` requires it to equal
        // `TestSignable::SIGN_LABEL`.
        label: String,
        #[serde(with = "hex::serde")]
        signature: Vec<u8>,
    }

    /// One entry of the `basic_crypto` interop vectors for a single cipher suite id.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        sign_with_label: SignatureInteropTestCase,
    }

    // Verifies every `sign_with_label` vector whose cipher suite the test provider
    // supports; vectors for unsupported suites are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        let test_cases: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());

        for test_case in test_cases {
            if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
                test_case.sign_with_label.verify(&cs).await;
            }
        }
    }

    /// Minimal [`Signable`] whose signable content is `context` followed by `content`.
    pub struct TestSignable {
        pub content: Vec<u8>,
        pub signature: Vec<u8>,
    }

    impl Signable<'_> for TestSignable {
        const SIGN_LABEL: &'static str = "SignWithLabel";

        type SigningContext = Vec<u8>;

        fn signature(&self) -> &[u8] {
            &self.signature
        }

        fn signable_content(
            &self,
            context: &Self::SigningContext,
        ) -> Result<Vec<u8>, mls_rs_codec::Error> {
            Ok([context.as_slice(), self.content.as_slice()].concat())
        }

        fn write_signature(&mut self, signature: Vec<u8>) {
            self.signature = signature
        }
    }

    impl SignatureInteropTestCase {
        /// Checks that `signature` is a valid signature over `content` under `public`,
        /// using an empty signing context.
        ///
        /// # Panics
        ///
        /// Panics if the vector's `label` differs from `TestSignable::SIGN_LABEL`, or if
        /// verification fails.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
            // `TestSignable` always signs with its fixed label, so a vector with any other
            // label could never verify. Fail with an explicit message rather than an
            // opaque `InvalidSignature` unwrap.
            assert_eq!(
                self.label,
                TestSignable::SIGN_LABEL,
                "interop vector label is not supported by TestSignable"
            );

            let public = self.public.clone().into();

            let signable = TestSignable {
                content: self.content.clone(),
                signature: self.signature.clone(),
            };

            signable.verify(cs, &public, &vec![]).await.unwrap();
        }
    }
}
177
#[cfg(test)]
mod tests {
    use super::{test_utils::TestSignable, *};
    use crate::{
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::{
            test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
        },
        group::test_utils::random_bytes,
    };
    use alloc::vec;
    use assert_matches::assert_matches;

    /// A stored signature vector: `signature` was produced over `content` and `context`
    /// with the secret key `signer`, and must verify under `public`.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        content: Vec<u8>,
        #[serde(with = "hex::serde")]
        context: Vec<u8>,
        #[serde(with = "hex::serde")]
        signature: Vec<u8>,
        #[serde(with = "hex::serde")]
        signer: Vec<u8>,
        #[serde(with = "hex::serde")]
        public: Vec<u8>,
    }

    /// Creates one fresh vector per cipher suite supported by `TestCryptoProvider`.
    ///
    /// Passed to `load_test_case_json!` as the generator for the `signatures` vectors.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    #[cfg_attr(coverage_nightly, coverage(off))]
    async fn generate_test_cases() -> Vec<TestCase> {
        let mut test_cases = Vec::new();

        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let provider = test_cipher_suite_provider(cipher_suite);

            let (signer, public) = provider.signature_key_generate().await.unwrap();

            let content = random_bytes(32);
            let context = random_bytes(32);

            let mut test_signable = TestSignable {
                content: content.clone(),
                signature: Vec::new(),
            };

            test_signable
                .sign(&provider, &signer, &context)
                .await
                .unwrap();

            test_cases.push(TestCase {
                cipher_suite: cipher_suite.into(),
                content,
                context,
                signature: test_signable.signature,
                signer: signer.to_vec(),
                public: public.to_vec(),
            });
        }

        test_cases
    }

    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(signatures, generate_test_cases().await)
    }

    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(signatures, generate_test_cases())
    }

    // For every stored vector with a supported cipher suite: re-sign and verify with the
    // stored key pair (non-wasm only), then verify the stored signature itself.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_signatures() {
        let cases = load_test_cases().await;

        for one_case in cases {
            let Some(cipher_suite_provider) = try_test_cipher_suite_provider(one_case.cipher_suite)
            else {
                continue;
            };

            let public_key = SignaturePublicKey::from(one_case.public);

            // Wasm uses incompatible signature secret key format
            #[cfg(not(target_arch = "wasm32"))]
            {
                // Test signature generation
                let mut test_signable = TestSignable {
                    content: one_case.content.clone(),
                    signature: Vec::new(),
                };

                let signature_key = SignatureSecretKey::from(one_case.signer);

                test_signable
                    .sign(&cipher_suite_provider, &signature_key, &one_case.context)
                    .await
                    .unwrap();

                test_signable
                    .verify(&cipher_suite_provider, &public_key, &one_case.context)
                    .await
                    .unwrap();
            }

            // Test verifying an existing signature
            let test_signable = TestSignable {
                content: one_case.content,
                signature: one_case.signature,
            };

            test_signable
                .verify(&cipher_suite_provider, &public_key, &one_case.context)
                .await
                .unwrap();
        }
    }

    // A signature must not verify under an unrelated public key.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_invalid_signature() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let (correct_secret, _) = cipher_suite_provider
            .signature_key_generate()
            .await
            .unwrap();
        let (_, incorrect_public) = cipher_suite_provider
            .signature_key_generate()
            .await
            .unwrap();

        let mut test_signable = TestSignable {
            content: random_bytes(32),
            signature: vec![],
        };

        test_signable
            .sign(&cipher_suite_provider, &correct_secret, &vec![])
            .await
            .unwrap();

        let res = test_signable
            .verify(&cipher_suite_provider, &incorrect_public, &vec![])
            .await;

        assert_matches!(res, Err(MlsError::InvalidSignature));
    }

    // A signature must not verify when a different signing context is supplied.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_invalid_context() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let (secret, public) = cipher_suite_provider
            .signature_key_generate()
            .await
            .unwrap();

        let correct_context = random_bytes(32);
        let incorrect_context = random_bytes(32);

        let mut test_signable = TestSignable {
            content: random_bytes(32),
            signature: vec![],
        };

        test_signable
            .sign(&cipher_suite_provider, &secret, &correct_context)
            .await
            .unwrap();

        let res = test_signable
            .verify(&cipher_suite_provider, &public, &incorrect_context)
            .await;

        assert_matches!(res, Err(MlsError::InvalidSignature));
    }

    // A signature must not verify once the signed content has been modified.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_invalid_content() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let (secret, public) = cipher_suite_provider
            .signature_key_generate()
            .await
            .unwrap();

        let context = random_bytes(32);

        let mut test_signable = TestSignable {
            content: random_bytes(32),
            signature: vec![],
        };

        test_signable
            .sign(&cipher_suite_provider, &secret, &context)
            .await
            .unwrap();

        // Deterministically tamper with the content after it has been signed.
        test_signable.content[0] ^= 0x01;

        let res = test_signable
            .verify(&cipher_suite_provider, &public, &context)
            .await;

        assert_matches!(res, Err(MlsError::InvalidSignature));
    }

    // A modified signature must be rejected for otherwise unchanged content.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_modified_signature() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let (secret, public) = cipher_suite_provider
            .signature_key_generate()
            .await
            .unwrap();

        let context = random_bytes(32);

        let mut test_signable = TestSignable {
            content: random_bytes(32),
            signature: vec![],
        };

        test_signable
            .sign(&cipher_suite_provider, &secret, &context)
            .await
            .unwrap();

        let first_byte = test_signable
            .signature
            .first_mut()
            .expect("signing must produce a non-empty signature");
        *first_byte ^= 0x01;

        let res = test_signable
            .verify(&cipher_suite_provider, &public, &context)
            .await;

        assert_matches!(res, Err(MlsError::InvalidSignature));
    }
}