1use borsh::BorshSerialize;
2use near_schema_checker_lib::ProtocolSchema;
3use serde::{Deserializer, Serializer};
4use sha2::Digest;
5use std::fmt;
6use std::hash::{Hash, Hasher};
7use std::io::Write;
8
/// A 32-byte hash value.
///
/// The hashing constructors in this file ([`CryptoHash::hash_bytes`],
/// [`CryptoHash::hash_borsh`], [`CryptoHash::hash_borsh_iter`] and [`hash()`]) produce
/// SHA-256 digests. The field is public, however, so any 32 bytes can be wrapped.
///
/// All textual forms (`Display`, `Debug`, `FromStr` and serde) use base58. The derived
/// `Ord`/`Eq` compare the raw bytes.
#[derive(
    Copy,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    derive_more::AsRef,
    derive_more::AsMut,
    arbitrary::Arbitrary,
    borsh::BorshDeserialize,
    borsh::BorshSerialize,
    ProtocolSchema,
)]
// Forward `AsRef`/`AsMut` to the inner array, so the hash can be viewed as e.g. `&[u8]`.
#[as_ref(forward)]
#[as_mut(forward)]
pub struct CryptoHash(pub [u8; 32]);
26
27impl CryptoHash {
28 pub const LENGTH: usize = 32;
29
30 pub const fn new() -> Self {
31 Self([0; Self::LENGTH])
32 }
33
34 pub fn hash_bytes(bytes: &[u8]) -> CryptoHash {
36 CryptoHash(sha2::Sha256::digest(bytes).into())
37 }
38
39 pub fn hash_borsh<T: BorshSerialize>(value: T) -> CryptoHash {
46 let mut hasher = sha2::Sha256::default();
47 value.serialize(&mut hasher).unwrap();
48 CryptoHash(hasher.finalize().into())
49 }
50
51 pub fn hash_borsh_iter<I>(values: I) -> CryptoHash
59 where
60 I: IntoIterator,
61 I::IntoIter: ExactSizeIterator,
62 I::Item: BorshSerialize,
63 {
64 let iter = values.into_iter();
65 let n = u32::try_from(iter.len()).unwrap();
66 let mut hasher = sha2::Sha256::default();
67 hasher.write_all(&n.to_le_bytes()).unwrap();
68 let count =
69 iter.inspect(|value| BorshSerialize::serialize(&value, &mut hasher).unwrap()).count();
70 assert_eq!(n as usize, count);
71 CryptoHash(hasher.finalize().into())
72 }
73
74 pub const fn as_bytes(&self) -> &[u8; Self::LENGTH] {
75 &self.0
76 }
77
78 fn to_base58_impl<Out>(self, visitor: impl FnOnce(&str) -> Out) -> Out {
84 let mut buffer = [0u8; 45];
88 let len = bs58::encode(self).into(&mut buffer[..]).unwrap();
89 let value = std::str::from_utf8(&buffer[..len]).unwrap();
90 visitor(value)
91 }
92
93 fn from_base58_impl(encoded: &str) -> Decode58Result {
99 let mut result = Self::new();
100 match bs58::decode(encoded).into(&mut result.0) {
101 Ok(len) if len == result.0.len() => Decode58Result::Ok(result),
102 Ok(_) | Err(bs58::decode::Error::BufferTooSmall) => Decode58Result::BadLength,
103 Err(err) => Decode58Result::Err(err),
104 }
105 }
106}
107
#[cfg(feature = "schemars")]
impl schemars::JsonSchema for CryptoHash {
    /// Name under which this type's schema is registered.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        // The name is a `'static` literal, so borrow it instead of allocating a `String`.
        std::borrow::Cow::Borrowed("CryptoHash")
    }

    /// Hashes serialize as (base58) strings, so the schema is the plain `String` schema.
    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        String::json_schema(generator)
    }
}
118
/// Outcome of [`CryptoHash::from_base58_impl`].
///
/// A wrong decoded size is kept separate from other base58 errors. That lets `FromStr`
/// and the serde visitor each report it in their own way.
enum Decode58Result {
    /// The input decoded to exactly [`CryptoHash::LENGTH`] bytes.
    Ok(CryptoHash),
    /// Decoding produced fewer bytes than a hash holds, or overflowed the 32-byte
    /// output buffer (`bs58::decode::Error::BufferTooSmall`).
    BadLength,
    /// Any other base58 decoding failure, e.g. a character outside the alphabet.
    Err(bs58::decode::Error),
}
129
130impl Default for CryptoHash {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136impl serde::Serialize for CryptoHash {
137 fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
138 where
139 S: Serializer,
140 {
141 self.to_base58_impl(|encoded| serializer.serialize_str(encoded))
142 }
143}
144
145struct Visitor;
150
151impl<'de> serde::de::Visitor<'de> for Visitor {
152 type Value = CryptoHash;
153
154 fn expecting(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 fmt.write_str("base58-encoded 256-bit hash")
156 }
157
158 fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> {
159 match CryptoHash::from_base58_impl(s) {
160 Decode58Result::Ok(result) => Ok(result),
161 Decode58Result::BadLength => Err(E::invalid_length(s.len(), &self)),
162 Decode58Result::Err(err) => Err(E::custom(err)),
163 }
164 }
165}
166
167impl<'de> serde::Deserialize<'de> for CryptoHash {
168 fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
169 where
170 D: Deserializer<'de>,
171 {
172 deserializer.deserialize_str(Visitor)
173 }
174}
175
176impl std::str::FromStr for CryptoHash {
177 type Err = Box<dyn std::error::Error + Send + Sync>;
178
179 fn from_str(encoded: &str) -> Result<Self, Self::Err> {
181 match Self::from_base58_impl(encoded) {
182 Decode58Result::Ok(result) => Ok(result),
183 Decode58Result::BadLength => Err("incorrect length for hash".into()),
184 Decode58Result::Err(err) => Err(err.into()),
185 }
186 }
187}
188
189impl TryFrom<&[u8]> for CryptoHash {
190 type Error = Box<dyn std::error::Error + Send + Sync>;
191
192 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
193 Ok(CryptoHash(bytes.try_into()?))
194 }
195}
196
197impl From<CryptoHash> for Vec<u8> {
198 fn from(hash: CryptoHash) -> Vec<u8> {
199 hash.0.to_vec()
200 }
201}
202
203impl From<&CryptoHash> for Vec<u8> {
204 fn from(hash: &CryptoHash) -> Vec<u8> {
205 hash.0.to_vec()
206 }
207}
208
209impl From<CryptoHash> for [u8; CryptoHash::LENGTH] {
210 fn from(hash: CryptoHash) -> [u8; CryptoHash::LENGTH] {
211 hash.0
212 }
213}
214
215impl fmt::Debug for CryptoHash {
216 fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result {
217 fmt::Display::fmt(self, fmtr)
218 }
219}
220
221impl fmt::Display for CryptoHash {
222 fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result {
223 self.to_base58_impl(|encoded| fmtr.write_str(encoded))
224 }
225}
226
227impl Hash for CryptoHash {
230 fn hash<H: Hasher>(&self, state: &mut H) {
231 state.write(self.as_ref());
232 }
233}
234
/// Returns the SHA-256 hash of `data`.
///
/// Free-function shorthand for [`CryptoHash::hash_bytes`].
pub fn hash(data: &[u8]) -> CryptoHash {
    CryptoHash::hash_bytes(data)
}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Struct with a `CryptoHash` field, used to check serde behaviour inside a container.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct Struct {
        hash: CryptoHash,
    }

    #[test]
    fn test_hash_borsh() {
        fn value<T: BorshSerialize>(want: &str, value: T) {
            assert_eq!(want, CryptoHash::hash_borsh(&value).to_string());
        }

        // Checks the slice form through `hash_borsh` and both `hash_borsh_iter` inputs.
        fn slice<T: BorshSerialize>(want: &str, slice: &[T]) {
            assert_eq!(want, CryptoHash::hash_borsh(slice).to_string());
            iter(want, slice.iter());
            iter(want, slice);
        }

        fn iter<I>(want: &str, iter: I)
        where
            I: IntoIterator,
            I::IntoIter: ExactSizeIterator,
            I::Item: BorshSerialize,
        {
            assert_eq!(want, CryptoHash::hash_borsh_iter(iter).to_string());
        }

        // Length-prefixed encodings of "foo": str, slices, Vec, explicit prefix, iterators.
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", "foo");
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", "foo".as_bytes());
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", &b"foo"[..]);
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", vec![b'f', b'o', b'o']);
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", [3, 0, 0, 0, b'f', b'o', b'o']);
        slice("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", "foo".as_bytes());
        iter(
            "CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt",
            "FOO".bytes().map(|ch| ch.to_ascii_lowercase()),
        );

        // Fixed-size arrays serialize without a length prefix.
        value("3yMApqCuCjXDWPrbjfR5mjCPTHqFG8Pux1TxQrEM35jj", b"foo");
        value("3yMApqCuCjXDWPrbjfR5mjCPTHqFG8Pux1TxQrEM35jj", [b'f', b'o', b'o']);
        value("3yMApqCuCjXDWPrbjfR5mjCPTHqFG8Pux1TxQrEM35jj", *b"foo");
        slice("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", &[b'f', b'o', b'o']);
    }

    #[test]
    fn test_base58_successes() {
        for (encoded, hash) in [
            ("11111111111111111111111111111111", CryptoHash::new()),
            ("CjNSmWXTWhC3EhRVtqLhRmWMTkRbU96wUACqxMtV1uGf", hash(&[0, 1, 2])),
        ] {
            assert_eq!(encoded, hash.to_string());
            assert_eq!(hash, CryptoHash::from_str(encoded).unwrap());

            let json = format!("\"{}\"", encoded);
            assert_eq!(json, serde_json::to_string(&hash).unwrap());
            assert_eq!(hash, serde_json::from_str::<CryptoHash>(&json).unwrap());
        }
    }

    #[test]
    fn test_serde_struct_field() {
        let value = Struct { hash: hash(&[0, 1, 2]) };
        let json = serde_json::to_string(&value).unwrap();
        assert_eq!(r#"{"hash":"CjNSmWXTWhC3EhRVtqLhRmWMTkRbU96wUACqxMtV1uGf"}"#, json);
        let back: Struct = serde_json::from_str(&json).unwrap();
        assert_eq!(value.hash, back.hash);
    }

    #[test]
    fn test_debug_matches_display() {
        let digest = hash(&[0, 1, 2]);
        assert_eq!(format!("{digest:?}"), digest.to_string());
    }

    #[test]
    fn test_byte_conversions() {
        let zeros = [0u8; CryptoHash::LENGTH];
        assert_eq!(CryptoHash::new(), CryptoHash::try_from(&zeros[..]).unwrap());
        assert!(CryptoHash::try_from(&zeros[..CryptoHash::LENGTH - 1]).is_err());
        assert!(CryptoHash::try_from(&[0u8; CryptoHash::LENGTH + 1][..]).is_err());

        let digest = hash(&[0, 1, 2]);
        let bytes: Vec<u8> = digest.into();
        assert_eq!(bytes, Vec::<u8>::from(&digest));
        assert_eq!(digest, CryptoHash::try_from(bytes.as_slice()).unwrap());
        assert_eq!(*digest.as_bytes(), <[u8; CryptoHash::LENGTH]>::from(digest));
    }

    #[test]
    fn test_from_str_failures() {
        fn test(input: &str, want_err: &str) {
            match CryptoHash::from_str(input) {
                Ok(got) => panic!("‘{input}’ should have failed; got ‘{got}’"),
                Err(err) => {
                    assert!(err.to_string().starts_with(want_err), "input: ‘{input}’; err: {err}")
                }
            }
        }

        test("foo-bar-baz", "provided string contained invalid character '-' at byte 3");

        // Valid base58 that decodes to too few or too many bytes.
        for encoded in &[
            "CjNSmWXTWhC3ELhRmWMTkRbU96wUACqxMtV1uGf".to_string(),
            "".to_string(),
            "1".repeat(31),
            "1".repeat(33),
            "1".repeat(1000),
        ] {
            test(encoded, "incorrect length for hash");
        }
    }

    #[test]
    fn test_serde_deserialize_failures() {
        fn test(input: &str, want_err: &str) {
            match serde_json::from_str::<CryptoHash>(input) {
                Ok(got) => panic!("‘{input}’ should have failed; got ‘{got}’"),
                Err(err) => {
                    assert!(err.to_string().starts_with(want_err), "input: ‘{input}’; err: {err}")
                }
            }
        }

        test("\"foo-bar-baz\"", "provided string contained invalid character");
        // Valid base58 that decodes to too few or too many bytes.
        for encoded in &[
            "\"CjNSmWXTWhC3ELhRmWMTkRbU96wUACqxMtV1uGf\"".to_string(),
            "\"\"".to_string(),
            format!("\"{}\"", "1".repeat(31)),
            format!("\"{}\"", "1".repeat(33)),
            format!("\"{}\"", "1".repeat(1000)),
        ] {
            test(encoded, "invalid length");
        }
    }
}