mls_rs/
hash_reference.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use core::{
6    fmt::{self, Debug},
7    ops::Deref,
8};
9
10use crate::client::MlsError;
11use crate::CipherSuiteProvider;
12use alloc::vec::Vec;
13use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
14use mls_rs_core::error::IntoAnyError;
15
16#[derive(MlsSize, MlsEncode)]
17struct RefHashInput<'a> {
18    #[mls_codec(with = "mls_rs_codec::byte_vec")]
19    pub label: &'a [u8],
20    #[mls_codec(with = "mls_rs_codec::byte_vec")]
21    pub value: &'a [u8],
22}
23
24impl Debug for RefHashInput<'_> {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        f.debug_struct("RefHashInput")
27            .field("label", &mls_rs_core::debug::pretty_bytes(self.label))
28            .field("value", &mls_rs_core::debug::pretty_bytes(self.value))
29            .finish()
30    }
31}
32
33#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, MlsSize, MlsEncode, MlsDecode)]
34#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
35#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
36pub struct HashReference(
37    #[mls_codec(with = "mls_rs_codec::byte_vec")]
38    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
39    Vec<u8>,
40);
41
42impl Debug for HashReference {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        mls_rs_core::debug::pretty_bytes(&self.0)
45            .named("HashReference")
46            .fmt(f)
47    }
48}
49
50impl Deref for HashReference {
51    type Target = [u8];
52
53    fn deref(&self) -> &Self::Target {
54        &self.0
55    }
56}
57
58impl AsRef<[u8]> for HashReference {
59    fn as_ref(&self) -> &[u8] {
60        &self.0
61    }
62}
63
64impl From<Vec<u8>> for HashReference {
65    fn from(val: Vec<u8>) -> Self {
66        Self(val)
67    }
68}
69
70impl HashReference {
71    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
72    pub async fn compute<P: CipherSuiteProvider>(
73        value: &[u8],
74        label: &[u8],
75        cipher_suite: &P,
76    ) -> Result<HashReference, MlsError> {
77        let input = RefHashInput { label, value };
78        let input_bytes = input.mls_encode_to_vec()?;
79
80        cipher_suite
81            .hash(&input_bytes)
82            .await
83            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
84            .map(HashReference)
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    use crate::crypto::test_utils::try_test_cipher_suite_provider;
    use alloc::string::String;
    use serde::{Deserialize, Serialize};

    // Only the vector generator needs these. It is compiled out in async builds.
    #[cfg(not(mls_build_async))]
    use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider};
    #[cfg(not(mls_build_async))]
    use alloc::string::ToString;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    // Reference-hash entry of one interop test case. The field names are the
    // JSON keys of the test vector and must not be renamed.
    #[derive(Debug, Deserialize, Serialize)]
    struct HashRefTestCase {
        label: String,
        #[serde(with = "hex::serde")]
        value: Vec<u8>,
        #[serde(with = "hex::serde")]
        out: Vec<u8>,
    }

    // One per-cipher-suite entry of the test vector (only the fields this
    // test reads are modelled).
    #[derive(Debug, Deserialize, Serialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        ref_hash: HashRefTestCase,
    }

    // Builds a vector entry for every known cipher suite by hashing a fixed
    // input under a fixed label.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        const INPUT: &[u8] = b"test input";
        const LABEL: &str = "test label";

        let mut cases = Vec::new();

        for suite in CipherSuite::all() {
            let provider = test_cipher_suite_provider(suite);
            let digest = HashReference::compute(INPUT, LABEL.as_bytes(), &provider).unwrap();

            cases.push(InteropTestCase {
                cipher_suite: suite.into(),
                ref_hash: HashRefTestCase {
                    label: LABEL.to_string(),
                    value: INPUT.to_vec(),
                    out: digest.to_vec(),
                },
            });
        }

        cases
    }

    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        panic!("Tests cannot be generated in async mode");
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        // Upstream vector: https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
        let test_cases: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, generate_test_vector());

        for case in test_cases {
            // Cipher suites the test crypto provider cannot supply are skipped.
            let provider = match try_test_cipher_suite_provider(case.cipher_suite) {
                Some(provider) => provider,
                None => continue,
            };

            let HashRefTestCase { label, value, out } = case.ref_hash;

            let computed = HashReference::compute(&value, label.as_bytes(), &provider)
                .await
                .unwrap();

            assert_eq!(&*computed, out.as_slice());
        }
    }
}