1#[cfg(not(feature = "std"))]
2use alloc::string::String;
3#[cfg(not(feature = "std"))]
4use core::fmt::Debug;
5use core::ops::Mul;
6
/// An unsigned 256-bit integer split into two 128-bit words `(high, low)`,
/// representing the value `self.0 · 2^128 + self.1`.
#[derive(Debug, PartialEq, Eq)]
struct U256(u128, u128);
11
12#[cfg(test)]
13use num_bigint::{BigUint, ToBigUint};
14
#[cfg(test)]
impl ToBigUint for U256 {
    /// Converts `high · 2^128 + low` into an arbitrary-precision integer.
    ///
    /// Both words are `u128`, whose conversion to `BigUint` never fails, so
    /// this always returns `Some`.
    fn to_biguint(&self) -> Option<BigUint> {
        let high = self.0.to_biguint()?;
        let low = self.1.to_biguint()?;
        // Shifting left by 128 bits is the multiplication by 2^128.
        Some((high << 128) + low)
    }
}
24
25#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)]
31pub struct HashMatrix(u128, u128, u128, u128);
32
33impl HashMatrix {
34 #[inline]
35 pub fn from_hex(hex: &str) -> Result<Self, String> {
36 if !hex.is_ascii() || hex.len() != 128 {
37 return Err(format!("invalid hex string: {:?}", hex));
38 }
39
40 let hex_bytes = hex.as_bytes();
41 let a = hex_bytes_to_u128(&hex_bytes[..32])?;
42 let b = hex_bytes_to_u128(&hex_bytes[32..64])?;
43 let c = hex_bytes_to_u128(&hex_bytes[64..96])?;
44 let d = hex_bytes_to_u128(&hex_bytes[96..])?;
45 Ok(Self(a, b, c, d))
46 }
47
48 #[inline]
50 pub fn to_hex(self) -> String {
51 format!(
52 "{:032x}{:032x}{:032x}{:032x}",
53 self.0, self.1, self.2, self.3
54 )
55 }
56
57 #[must_use]
58 #[inline]
59 pub fn to_be_bytes(&self) -> [u8; 64] {
60 let mut result = [0u8; 64];
61 result[..16].copy_from_slice(&self.0.to_be_bytes());
62 result[16..32].copy_from_slice(&self.1.to_be_bytes());
63 result[32..48].copy_from_slice(&self.2.to_be_bytes());
64 result[48..].copy_from_slice(&self.3.to_be_bytes());
65 result
66 }
67
68 #[must_use]
69 #[inline]
70 pub fn to_le_bytes(&self) -> [u8; 64] {
71 let mut result = [0u8; 64];
72 result[..16].copy_from_slice(&self.0.to_le_bytes());
73 result[16..32].copy_from_slice(&self.1.to_le_bytes());
74 result[32..48].copy_from_slice(&self.2.to_le_bytes());
75 result[48..].copy_from_slice(&self.3.to_le_bytes());
76 result
77 }
78}
79
80impl Default for HashMatrix {
81 fn default() -> Self {
82 I
83 }
84}
85
86impl Mul for HashMatrix {
87 type Output = Self;
88 #[inline]
89 fn mul(self, rhs: Self) -> Self {
90 matmul(self, rhs)
91 }
92}
93
94pub(crate) const A: HashMatrix = HashMatrix(1, 2, 0, 1);
95
96pub(crate) const B: HashMatrix = HashMatrix(1, 0, 2, 1);
97
98pub static I: HashMatrix = HashMatrix(1, 0, 0, 1);
99
100const SUCC_P: u128 = 1 << 127;
101const P: u128 = SUCC_P - 1;
102
103const LO_MASK: u128 = 0xffff_ffff_ffff_ffff;
104
105#[inline]
106const fn mul(x: u128, y: u128) -> U256 {
107 let x_lo = x & LO_MASK;
108 let y_lo = y & LO_MASK;
109
110 let x_hi = x >> 64;
111 let y_hi = y >> 64;
112
113 let x_hi_y_lo = x_hi.wrapping_mul(y_lo);
114 let y_hi_x_lo = y_hi.wrapping_mul(x_lo);
115
116 let x_hi_y_lo_shifted = x_hi_y_lo << 64;
117 let y_hi_x_lo_shifted = y_hi_x_lo << 64;
118
119 let (lo_sum_1, carry_bool_1) = x_hi_y_lo_shifted.overflowing_add(y_hi_x_lo_shifted);
120 let (lo_sum_2, carry_bool_2) = lo_sum_1.overflowing_add(x_lo.wrapping_mul(y_lo));
121 let carry = carry_bool_1 as u128 + carry_bool_2 as u128;
122
123 U256(
124 x_hi.wrapping_mul(y_hi)
125 .wrapping_add(x_hi_y_lo_shifted >> 64)
126 .wrapping_add(y_hi_x_lo_shifted >> 64)
127 .wrapping_add(carry),
128 lo_sum_2,
129 )
130}
131
132#[inline]
133const fn add(x: U256, y: U256) -> U256 {
134 let (low, carry) = x.1.overflowing_add(y.1);
138 let high = x.0 + y.0 + carry as u128;
139 U256(high, low)
140}
141
142#[inline]
148const fn mod_p_round_1(n: U256) -> U256 {
149 let low_bits = n.1 & P; let intermediate_bits = (n.0 << 1) | (n.1 >> 127); let high_bit = n.0 >> 127;
152 let (sum, carry_bool) = low_bits.overflowing_add(intermediate_bits);
153 U256(carry_bool as u128 + high_bit, sum)
154}
155
156#[inline]
157const fn mod_p_round_2(n: U256) -> u128 {
158 let low_bits = n.1 & P; let intermediate_bits = (n.0 << 1) | (n.1 >> 127); low_bits + intermediate_bits
161}
162
163#[inline]
164const fn mod_p_round_3(n: u128) -> u128 {
165 let low_bits = n & P; let intermediate_bit = n >> 127; low_bits + intermediate_bit
168}
169
170#[inline]
171const fn constmod_p(n: U256) -> u128 {
172 let n1 = mod_p_round_1(n);
173 let n2 = mod_p_round_2(n1);
174 let n3 = mod_p_round_3(n2);
175
176 ((n3 + 1) & P).saturating_sub(1)
177}
178
179#[inline]
180fn mod_p(mut n: U256) -> u128 {
181 if n.0 != 0 {
183 n = mod_p_round_1(n);
184 }
185 let mut n_small = if n.0 != 0 || (n.1 > P) {
187 mod_p_round_2(n)
188 } else {
189 n.1
190 };
191 if n_small > P {
193 n_small = mod_p_round_3(n_small);
194 }
195 if n_small == P {
198 0
199 } else {
200 n_small
201 }
202}
203
204#[inline]
205pub fn matmul(a: HashMatrix, b: HashMatrix) -> HashMatrix {
206 HashMatrix(
207 mod_p(add(mul(a.0, b.0), mul(a.1, b.2))),
208 mod_p(add(mul(a.0, b.1), mul(a.1, b.3))),
209 mod_p(add(mul(a.2, b.0), mul(a.3, b.2))),
210 mod_p(add(mul(a.2, b.1), mul(a.3, b.3))),
211 )
212}
213
214#[must_use]
218#[inline]
219pub const fn constmatmul(a: HashMatrix, b: HashMatrix) -> HashMatrix {
220 HashMatrix(
221 constmod_p(add(mul(a.0, b.0), mul(a.1, b.2))),
222 constmod_p(add(mul(a.0, b.1), mul(a.1, b.3))),
223 constmod_p(add(mul(a.2, b.0), mul(a.3, b.2))),
224 constmod_p(add(mul(a.2, b.1), mul(a.3, b.3))),
225 )
226}
227
228fn hex_bytes_to_u128(hex_bytes: &[u8]) -> Result<u128, String> {
229 let mut hex_bytes = hex_bytes.iter().copied();
230 let mut result = [0u8; 16];
231 for byte in result.iter_mut() {
232 let digit1 = hex_digit_to_u8(hex_bytes.next().unwrap())?;
233 let digit2 = hex_digit_to_u8(hex_bytes.next().unwrap())?;
234 *byte = (digit1 << 4) | digit2;
235 }
236 Ok(u128::from_be_bytes(result))
237}
238
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`.
///
/// # Errors
/// Returns `Err` naming the offending byte (as a number) for any other input.
fn hex_digit_to_u8(hex_digit: u8) -> Result<u8, String> {
    char::from(hex_digit)
        .to_digit(16)
        // `to_digit(16)` yields at most 15, so the narrowing cast is lossless.
        .map(|value| value as u8)
        .ok_or_else(|| format!("invalid hex character: {:?}", hex_digit))
}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    use alloc::vec::Vec;

    #[test]
    fn it_works() {
        // 128 × 128 → 256-bit products, including carries out of the cross terms.
        assert_eq!(mul(1 << 127, 2), U256(1, 0));
        assert_eq!(
            mul(1 << 127, 1 << 127),
            U256(85070591730234615865843651857942052864, 0)
        );
        assert_eq!(mul(4, 4), U256(0, 16));
        assert_eq!(
            mul((1 << 127) + 4, (1 << 127) + 4),
            U256(85070591730234615865843651857942052868, 16)
        );
        // (2^128 - 1)^2 = (2^128 - 2) · 2^128 + 1
        assert_eq!(mul(u128::MAX, u128::MAX), U256(u128::MAX - 1, 1));

        // Reduction modulo P, including inputs congruent to 0 that are not literally 0.
        assert_eq!(mod_p(U256(0, P)), 0);
        assert_eq!(mod_p(U256(0, P + 1)), 1);
        assert_eq!(mod_p(U256(0, 0)), 0);
        assert_eq!(mod_p(U256(0, 1)), 1);
        assert_eq!(mod_p(U256(0, P - 1)), P - 1);
        assert_eq!(mod_p(U256(0, 1 << 127)), 1);
        assert_eq!(mod_p(U256(1, P)), 2);
        assert_eq!(mod_p(U256(1, 0)), 2);
        assert_eq!(mod_p(U256(P, 0)), 0);
        assert_eq!(mod_p(U256(P, P)), 0);
        assert_eq!(mod_p(U256(0, u128::MAX)), 1);

        // Row-major 2×2 products, reduced modulo P.
        assert_eq!(
            HashMatrix(1, 0, 0, 1) * HashMatrix(1, 0, 0, 1),
            HashMatrix(1, 0, 0, 1)
        );
        assert_eq!(
            HashMatrix(2, 0, 0, 2) * HashMatrix(2, 0, 0, 2),
            HashMatrix(4, 0, 0, 4)
        );
        assert_eq!(
            HashMatrix(0, 1, 1, 0) * HashMatrix(2, 0, 0, 2),
            HashMatrix(0, 2, 2, 0)
        );
        // Scalar matrices commute with every matrix, so the reversed product
        // must give the same result.
        assert_eq!(
            HashMatrix(2, 0, 0, 2) * HashMatrix(0, 1, 1, 0),
            HashMatrix(0, 2, 2, 0)
        );
        assert_eq!(
            HashMatrix(1, 0, 0, 1) * HashMatrix(P, 0, 0, P),
            HashMatrix(0, 0, 0, 0)
        );
        assert_eq!(
            HashMatrix(1, 0, 0, 1) * HashMatrix(P + 1, P + 5, 2, P),
            HashMatrix(1, 5, 2, 0)
        );
        assert_eq!(
            HashMatrix(P + 1, P + 3, P + 4, P + 5) * HashMatrix(P + 1, P, P, P + 1),
            HashMatrix(1, 3, 4, 5)
        );

        // The branch-free const variant must agree with `matmul`.
        assert_eq!(
            constmatmul(
                HashMatrix(P + 1, P + 3, P + 4, P + 5),
                HashMatrix(P + 1, P, P, P + 1)
            ),
            HashMatrix(1, 3, 4, 5)
        );
    }

    #[test]
    fn test_hex_encoding_and_decoding() {
        let hash = HashMatrix(0, 0, 0, 0);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);

        let hash = HashMatrix(0, 0, 0, 1);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);

        let hash = HashMatrix(0, 0, 0, 31);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);

        let hash = HashMatrix(0, 0, 0, 89);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);

        let hash = HashMatrix(0, 0, 0, 1 << 34);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);

        let hash = HashMatrix(0, 1 << 31, 0, 1 << 34);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);
    }

    use quickcheck::*;

    quickcheck! {
        fn composition(a: Vec<u8>, b: Vec<u8>) -> bool {
            let mut a = a;
            let mut b = b;
            let h1 = hash(&a) * hash(&b);
            a.append(&mut b);
            hash(&a) == h1
        }
    }

    quickcheck! {
        fn hex_encoding_and_decoding(bytes: Vec<u8>) -> bool {
            let hash = hash(&bytes);
            HashMatrix::from_hex(&hash.to_hex()).unwrap() == hash
        }
    }

    quickcheck! {
        fn mul_check(a: u128, b: u128) -> bool {
            use num_bigint::*;
            let res = mul(a, b);

            a.to_biguint().unwrap() * b.to_biguint().unwrap()
                == res.to_biguint().unwrap()
        }
    }

    // NOTE(review): if quickcheck generates full-range u128 values, a*b + c*d
    // can reach 2^256 or more and overflow the high-word `+` in `add`.
    // This applies to the next two properties.
    quickcheck! {
        fn add_check(a: u128, b: u128, c: u128, d: u128) -> bool {
            let res = add(mul(a, b), mul(c, d));

            let big_res = a.to_biguint().unwrap() * b.to_biguint().unwrap()
                + c.to_biguint().unwrap() * d.to_biguint().unwrap();

            res.to_biguint().unwrap() == big_res
        }
    }

    quickcheck! {
        fn mod_p_check(a: u128, b: u128, c: u128, d: u128) -> bool {
            let res = mod_p(add(mul(a, b), mul(c, d)));

            let big_res = (a.to_biguint().unwrap() * b.to_biguint().unwrap()
                + c.to_biguint().unwrap() * d.to_biguint().unwrap())
                % P.to_biguint().unwrap();

            res.to_biguint().unwrap() == big_res
        }
    }

    quickcheck! {
        fn collision_search(a: Vec<u8>, b: Vec<u8>) -> bool {
            let ares = hash(&a);
            let bres = hash(&b);
            ares != bres || a == b
        }
    }

    #[cfg(feature = "std")]
    quickcheck! {
        fn par_equiv(a: Vec<u8>) -> bool {
            let h0 = hash(&a);
            let h1 = hash_par(&a);
            h0 == h1
        }
    }
}
399}