1use borsh::BorshSerialize;
2use near_schema_checker_lib::ProtocolSchema;
3use serde::{Deserializer, Serializer};
4use sha2::Digest;
5use std::fmt;
6use std::hash::{Hash, Hasher};
7use std::io::Write;
8
/// A 32-byte hash value.
///
/// Values are normally produced by [`CryptoHash::hash_bytes`] (SHA-256) or the
/// other `hash_*` constructors; [`CryptoHash::new`] gives the all-zero hash.
///
/// Human-readable representations (`Display`, `Debug`, `FromStr` and serde)
/// use base58; Borsh (de)serialization is derived.  `AsRef`/`AsMut` are
/// forwarded to the inner byte array (so e.g. `AsRef<[u8]>` is available).
#[derive(
    Copy,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    derive_more::AsRef,
    derive_more::AsMut,
    arbitrary::Arbitrary,
    borsh::BorshDeserialize,
    borsh::BorshSerialize,
    ProtocolSchema,
)]
#[as_ref(forward)]
#[as_mut(forward)]
pub struct CryptoHash(pub [u8; 32]);
27
28impl CryptoHash {
29 pub const LENGTH: usize = 32;
30
31 pub const fn new() -> Self {
32 Self([0; Self::LENGTH])
33 }
34
35 pub fn hash_bytes(bytes: &[u8]) -> CryptoHash {
37 CryptoHash(sha2::Sha256::digest(bytes).into())
38 }
39
40 pub fn hash_borsh<T: BorshSerialize>(value: T) -> CryptoHash {
47 let mut hasher = sha2::Sha256::default();
48 value.serialize(&mut hasher).unwrap();
49 CryptoHash(hasher.finalize().into())
50 }
51
52 pub fn hash_borsh_iter<I>(values: I) -> CryptoHash
60 where
61 I: IntoIterator,
62 I::IntoIter: ExactSizeIterator,
63 I::Item: BorshSerialize,
64 {
65 let iter = values.into_iter();
66 let n = u32::try_from(iter.len()).unwrap();
67 let mut hasher = sha2::Sha256::default();
68 hasher.write_all(&n.to_le_bytes()).unwrap();
69 let count =
70 iter.inspect(|value| BorshSerialize::serialize(&value, &mut hasher).unwrap()).count();
71 assert_eq!(n as usize, count);
72 CryptoHash(hasher.finalize().into())
73 }
74
75 pub const fn as_bytes(&self) -> &[u8; Self::LENGTH] {
76 &self.0
77 }
78
79 fn to_base58_impl<Out>(self, visitor: impl FnOnce(&str) -> Out) -> Out {
85 let mut buffer = [0u8; 45];
89 let len = bs58::encode(self).into(&mut buffer[..]).unwrap();
90 let value = std::str::from_utf8(&buffer[..len]).unwrap();
91 visitor(value)
92 }
93
94 fn from_base58_impl(encoded: &str) -> Decode58Result {
100 let mut result = Self::new();
101 match bs58::decode(encoded).into(&mut result.0) {
102 Ok(len) if len == result.0.len() => Decode58Result::Ok(result),
103 Ok(_) | Err(bs58::decode::Error::BufferTooSmall) => Decode58Result::BadLength,
104 Err(err) => Decode58Result::Err(err),
105 }
106 }
107}
108
#[cfg(feature = "schemars")]
impl schemars::JsonSchema for CryptoHash {
    /// Hashes appear under the name `CryptoHash` in generated schemas.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        // The name is a static string, so borrow it instead of allocating a
        // fresh `String` on every call.
        std::borrow::Cow::Borrowed("CryptoHash")
    }

    /// A hash is serialized to JSON as a (base58) string, so it reuses the
    /// schema of `String`.
    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        String::json_schema(generator)
    }
}
119
120enum Decode58Result {
122 Ok(CryptoHash),
124 BadLength,
126 Err(bs58::decode::Error),
129}
130
131impl Default for CryptoHash {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137impl serde::Serialize for CryptoHash {
138 fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
139 where
140 S: Serializer,
141 {
142 self.to_base58_impl(|encoded| serializer.serialize_str(encoded))
143 }
144}
145
/// Serde visitor used by the `Deserialize` impl of [`CryptoHash`].
///
/// It accepts a string and base58-decodes it into a hash.  A wrong decoded
/// length is reported via `invalid_length`; any other decoding failure is
/// reported as a custom error.
struct Visitor;
151
152impl<'de> serde::de::Visitor<'de> for Visitor {
153 type Value = CryptoHash;
154
155 fn expecting(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 fmt.write_str("base58-encoded 256-bit hash")
157 }
158
159 fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> {
160 match CryptoHash::from_base58_impl(s) {
161 Decode58Result::Ok(result) => Ok(result),
162 Decode58Result::BadLength => Err(E::invalid_length(s.len(), &self)),
163 Decode58Result::Err(err) => Err(E::custom(err)),
164 }
165 }
166}
167
168impl<'de> serde::Deserialize<'de> for CryptoHash {
169 fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
170 where
171 D: Deserializer<'de>,
172 {
173 deserializer.deserialize_str(Visitor)
174 }
175}
176
177impl std::str::FromStr for CryptoHash {
178 type Err = Box<dyn std::error::Error + Send + Sync>;
179
180 fn from_str(encoded: &str) -> Result<Self, Self::Err> {
182 match Self::from_base58_impl(encoded) {
183 Decode58Result::Ok(result) => Ok(result),
184 Decode58Result::BadLength => Err("incorrect length for hash".into()),
185 Decode58Result::Err(err) => Err(err.into()),
186 }
187 }
188}
189
190impl TryFrom<&[u8]> for CryptoHash {
191 type Error = Box<dyn std::error::Error + Send + Sync>;
192
193 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
194 Ok(CryptoHash(bytes.try_into()?))
195 }
196}
197
198impl From<CryptoHash> for Vec<u8> {
199 fn from(hash: CryptoHash) -> Vec<u8> {
200 hash.0.to_vec()
201 }
202}
203
204impl From<&CryptoHash> for Vec<u8> {
205 fn from(hash: &CryptoHash) -> Vec<u8> {
206 hash.0.to_vec()
207 }
208}
209
210impl From<CryptoHash> for [u8; CryptoHash::LENGTH] {
211 fn from(hash: CryptoHash) -> [u8; CryptoHash::LENGTH] {
212 hash.0
213 }
214}
215
216impl fmt::Debug for CryptoHash {
217 fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result {
218 fmt::Display::fmt(self, fmtr)
219 }
220}
221
222impl fmt::Display for CryptoHash {
223 fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result {
224 self.to_base58_impl(|encoded| fmtr.write_str(encoded))
225 }
226}
227
228impl Hash for CryptoHash {
231 fn hash<H: Hasher>(&self, state: &mut H) {
232 state.write(self.as_ref());
233 }
234}
235
/// Calculates the SHA-256 hash of `data`.
///
/// Free-function shorthand for [`CryptoHash::hash_bytes`].
pub fn hash(data: &[u8]) -> CryptoHash {
    CryptoHash::hash_bytes(data)
}
249
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[derive(serde::Deserialize, serde::Serialize)]
    struct Struct {
        hash: CryptoHash,
    }

    #[test]
    fn test_hash_borsh() {
        fn value<T: BorshSerialize>(want: &str, value: T) {
            assert_eq!(want, CryptoHash::hash_borsh(&value).to_string());
        }

        fn slice<T: BorshSerialize>(want: &str, slice: &[T]) {
            assert_eq!(want, CryptoHash::hash_borsh(slice).to_string());
            iter(want, slice.iter());
            iter(want, slice);
        }

        fn iter<I>(want: &str, iter: I)
        where
            I: IntoIterator,
            I::IntoIter: ExactSizeIterator,
            I::Item: BorshSerialize,
        {
            assert_eq!(want, CryptoHash::hash_borsh_iter(iter).to_string());
        }

        // Length-prefixed encodings of "foo": strings, slices, vectors and an
        // explicit little-endian `u32` length followed by the bytes.
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", "foo");
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", "foo".as_bytes());
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", &b"foo"[..]);
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", [3, 0, 0, 0, b'f', b'o', b'o']);
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", String::from("foo"));
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", b"foo".to_vec());
        slice("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", "foo".as_bytes());
        iter(
            "CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt",
            "FOO".bytes().map(|ch| ch.to_ascii_lowercase()),
        );

        // Fixed-size arrays (by value or by reference) carry no length
        // prefix, while slices always do.
        value("3yMApqCuCjXDWPrbjfR5mjCPTHqFG8Pux1TxQrEM35jj", b"foo");
        value("3yMApqCuCjXDWPrbjfR5mjCPTHqFG8Pux1TxQrEM35jj", [b'f', b'o', b'o']);
        value("3yMApqCuCjXDWPrbjfR5mjCPTHqFG8Pux1TxQrEM35jj", *b"foo");
        slice("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", &[b'f', b'o', b'o']);
    }

    #[test]
    fn test_base58_successes() {
        for (encoded, hash) in [
            ("11111111111111111111111111111111", CryptoHash::new()),
            ("CjNSmWXTWhC3EhRVtqLhRmWMTkRbU96wUACqxMtV1uGf", hash(&[0, 1, 2])),
        ] {
            assert_eq!(encoded, hash.to_string());
            assert_eq!(hash, CryptoHash::from_str(encoded).unwrap());

            let json = format!("\"{}\"", encoded);
            assert_eq!(json, serde_json::to_string(&hash).unwrap());
            assert_eq!(hash, serde_json::from_str::<CryptoHash>(&json).unwrap());
        }
    }

    #[test]
    fn test_from_str_failures() {
        fn test(input: &str, want_err: &str) {
            match CryptoHash::from_str(input) {
                Ok(got) => panic!("‘{input}’ should have failed; got ‘{got}’"),
                Err(err) => {
                    assert!(err.to_string().starts_with(want_err), "input: ‘{input}’; err: {err}")
                }
            }
        }

        test("foo-bar-baz", "provided string contained invalid character '-' at byte 3");

        // Valid base58 that decodes to too few or too many bytes.
        for encoded in &[
            "CjNSmWXTWhC3ELhRmWMTkRbU96wUACqxMtV1uGf".to_string(),
            "".to_string(),
            "1".repeat(31),
            "1".repeat(33),
            "1".repeat(1000),
        ] {
            test(encoded, "incorrect length for hash");
        }
    }

    #[test]
    fn test_serde_deserialize_failures() {
        fn test(input: &str, want_err: &str) {
            match serde_json::from_str::<CryptoHash>(input) {
                Ok(got) => panic!("‘{input}’ should have failed; got ‘{got}’"),
                Err(err) => {
                    assert!(err.to_string().starts_with(want_err), "input: ‘{input}’; err: {err}")
                }
            }
        }

        test("\"foo-bar-baz\"", "provided string contained invalid character");
        // Valid base58 that decodes to too few or too many bytes.
        for encoded in &[
            "\"CjNSmWXTWhC3ELhRmWMTkRbU96wUACqxMtV1uGf\"".to_string(),
            "\"\"".to_string(),
            format!("\"{}\"", "1".repeat(31)),
            format!("\"{}\"", "1".repeat(33)),
            format!("\"{}\"", "1".repeat(1000)),
        ] {
            test(encoded, "invalid length");
        }
    }
}