1use core::{
6 fmt::{self, Debug},
7 ops::Deref,
8};
9
10use crate::client::MlsError;
11use crate::CipherSuiteProvider;
12use alloc::vec::Vec;
13use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
14use mls_rs_core::error::IntoAnyError;
15
/// Borrowed `label` / `value` pair that `HashReference::compute` serializes
/// and then hashes.
///
/// Both fields are encoded with `mls_rs_codec::byte_vec`. `label` is declared
/// before `value`; presumably the derived encoder writes fields in declaration
/// order, so reordering them would change the hashed bytes (confirm against
/// the `MlsEncode` derive).
#[derive(MlsSize, MlsEncode)]
struct RefHashInput<'a> {
    /// Label bytes supplied by the caller of `HashReference::compute`.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub label: &'a [u8],
    /// Value bytes supplied by the caller of `HashReference::compute`.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub value: &'a [u8],
}
23
24impl Debug for RefHashInput<'_> {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 f.debug_struct("RefHashInput")
27 .field("label", &mls_rs_core::debug::pretty_bytes(self.label))
28 .field("value", &mls_rs_core::debug::pretty_bytes(self.value))
29 .finish()
30 }
31}
32
/// Newtype over `Vec<u8>` holding the bytes returned by
/// `HashReference::compute`.
///
/// The bytes are readable through `Deref<Target = [u8]>` and `AsRef<[u8]>`.
/// `From<Vec<u8>>` wraps any byte vector as-is, without validating its length
/// or content. Equality, ordering and hashing are derived from the inner
/// bytes.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HashReference(
    // Encoded with `mls_rs_codec::byte_vec`; with the `serde` feature enabled
    // the bytes go through `mls_rs_core::vec_serde`.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);
41
42impl Debug for HashReference {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 mls_rs_core::debug::pretty_bytes(&self.0)
45 .named("HashReference")
46 .fmt(f)
47 }
48}
49
50impl Deref for HashReference {
51 type Target = [u8];
52
53 fn deref(&self) -> &Self::Target {
54 &self.0
55 }
56}
57
58impl AsRef<[u8]> for HashReference {
59 fn as_ref(&self) -> &[u8] {
60 &self.0
61 }
62}
63
64impl From<Vec<u8>> for HashReference {
65 fn from(val: Vec<u8>) -> Self {
66 Self(val)
67 }
68}
69
70impl HashReference {
71 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
72 pub async fn compute<P: CipherSuiteProvider>(
73 value: &[u8],
74 label: &[u8],
75 cipher_suite: &P,
76 ) -> Result<HashReference, MlsError> {
77 let input = RefHashInput { label, value };
78 let input_bytes = input.mls_encode_to_vec()?;
79
80 cipher_suite
81 .hash(&input_bytes)
82 .await
83 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
84 .map(HashReference)
85 }
86}
87
#[cfg(test)]
mod tests {
    use crate::crypto::test_utils::try_test_cipher_suite_provider;

    #[cfg(not(mls_build_async))]
    use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider};

    use super::*;
    use alloc::string::String;
    use serde::{Deserialize, Serialize};

    #[cfg(not(mls_build_async))]
    use alloc::string::ToString;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// One reference-hash vector: the label, the hex-encoded input value and
    /// the hex-encoded expected output.
    #[derive(Debug, Deserialize, Serialize)]
    struct HashRefTestCase {
        label: String,
        #[serde(with = "hex::serde")]
        value: Vec<u8>,
        #[serde(with = "hex::serde")]
        out: Vec<u8>,
    }

    /// A reference-hash vector together with the numeric id of the cipher
    /// suite it was produced with.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        ref_hash: HashRefTestCase,
    }

    /// Builds one vector per cipher suite from a fixed input and label, using
    /// `HashReference::compute` itself to produce the expected output.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        const INPUT: &[u8] = b"test input";
        const LABEL: &str = "test label";

        CipherSuite::all()
            .map(|cipher_suite| {
                let cs_provider = test_cipher_suite_provider(cipher_suite);
                let reference =
                    HashReference::compute(INPUT, LABEL.as_bytes(), &cs_provider).unwrap();

                InteropTestCase {
                    cipher_suite: cipher_suite.into(),
                    ref_hash: HashRefTestCase {
                        label: LABEL.to_string(),
                        value: INPUT.to_vec(),
                        out: reference.to_vec(),
                    },
                }
            })
            .collect()
    }

    /// Async builds cannot generate vectors; calling this panics.
    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        panic!("Tests cannot be generated in async mode");
    }

    /// Recomputes every loaded vector whose cipher suite has an available
    /// provider and checks the result against the recorded output. Vectors
    /// for unsupported cipher suites are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        // NOTE(review): the generator expression is handed to the macro;
        // presumably it is only evaluated when the vector file has to be
        // created - confirm against `load_test_case_json!`.
        let vectors: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, generate_test_vector());

        for InteropTestCase {
            cipher_suite,
            ref_hash,
        } in vectors
        {
            if let Some(cs) = try_test_cipher_suite_provider(cipher_suite) {
                let computed =
                    HashReference::compute(&ref_hash.value, ref_hash.label.as_bytes(), &cs)
                        .await
                        .unwrap();

                assert_eq!(&*computed, &ref_hash.out);
            }
        }
    }
}