bromberg_sl2/
hash_matrix.rs

1#[cfg(not(feature = "std"))]
2use alloc::string::String;
3#[cfg(not(feature = "std"))]
4use core::fmt::Debug;
5use core::ops::Mul;
6
/// Unsigned 256-bit integer stored as two 128-bit words, most significant word
/// first: `U256(high, low)` stands for `high * 2^128 + low`.
// TODO: try u64 or u32 limbs instead and measure whether it is faster.
#[derive(Debug, PartialEq, Eq)]
struct U256(u128, u128);
11
12#[cfg(test)]
13use num_bigint::{BigUint, ToBigUint};
14
#[cfg(test)]
impl ToBigUint for U256 {
    /// Reassemble the full value as an arbitrary-precision integer,
    /// `high * 2^128 + low`.
    fn to_biguint(&self) -> Option<BigUint> {
        let high = self.0.to_biguint()?;
        let low = self.1.to_biguint()?;
        Some((high << 128) + low)
    }
}
24
/// The type of hash values: a 2×2 matrix of integers modulo `P = 2^127 - 1`,
/// stored row-major as four `u128` entries (512 bits of space in total).
///
/// Hash values are meant to be produced by the provided
/// [`BrombergHashable`](trait.BrombergHashable.html) instances and combined
/// with `*`. Not every 512-bit pattern is a valid hash: every entry of a
/// product is reduced below `P`, so patterns with an entry `>= P` never arise
/// from multiplication. [`HashMatrix::from_hex`] parses any 128 hex digits and
/// does not check this.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct HashMatrix(u128, u128, u128, u128);
32
33impl HashMatrix {
34    #[inline]
35    pub fn from_hex(hex: &str) -> Result<Self, String> {
36        if !hex.is_ascii() || hex.len() != 128 {
37            return Err(format!("invalid hex string: {:?}", hex));
38        }
39
40        let hex_bytes = hex.as_bytes();
41        let a = hex_bytes_to_u128(&hex_bytes[..32])?;
42        let b = hex_bytes_to_u128(&hex_bytes[32..64])?;
43        let c = hex_bytes_to_u128(&hex_bytes[64..96])?;
44        let d = hex_bytes_to_u128(&hex_bytes[96..])?;
45        Ok(Self(a, b, c, d))
46    }
47
48    /// Produce a hex digest of the hash. This will be a 128 hex digits.
49    #[inline]
50    pub fn to_hex(self) -> String {
51        format!(
52            "{:032x}{:032x}{:032x}{:032x}",
53            self.0, self.1, self.2, self.3
54        )
55    }
56
57    #[must_use]
58    #[inline]
59    pub fn to_be_bytes(&self) -> [u8; 64] {
60        let mut result = [0u8; 64];
61        result[..16].copy_from_slice(&self.0.to_be_bytes());
62        result[16..32].copy_from_slice(&self.1.to_be_bytes());
63        result[32..48].copy_from_slice(&self.2.to_be_bytes());
64        result[48..].copy_from_slice(&self.3.to_be_bytes());
65        result
66    }
67
68    #[must_use]
69    #[inline]
70    pub fn to_le_bytes(&self) -> [u8; 64] {
71        let mut result = [0u8; 64];
72        result[..16].copy_from_slice(&self.0.to_le_bytes());
73        result[16..32].copy_from_slice(&self.1.to_le_bytes());
74        result[32..48].copy_from_slice(&self.2.to_le_bytes());
75        result[48..].copy_from_slice(&self.3.to_le_bytes());
76        result
77    }
78}
79
80impl Default for HashMatrix {
81    fn default() -> Self {
82        I
83    }
84}
85
86impl Mul for HashMatrix {
87    type Output = Self;
88    #[inline]
89    fn mul(self, rhs: Self) -> Self {
90        matmul(self, rhs)
91    }
92}
93
94pub(crate) const A: HashMatrix = HashMatrix(1, 2, 0, 1);
95
96pub(crate) const B: HashMatrix = HashMatrix(1, 0, 2, 1);
97
98pub static I: HashMatrix = HashMatrix(1, 0, 0, 1);
99
/// `2^127`, the integer just above the modulus.
const SUCC_P: u128 = 1 << 127;
/// The Mersenne prime modulus `2^127 - 1`. Its bit pattern is also a mask for
/// the low 127 bits of a `u128`.
const P: u128 = SUCC_P - 1;

/// Mask selecting the low 64 bits of a `u128`.
const LO_MASK: u128 = u64::MAX as u128;
104
105#[inline]
106const fn mul(x: u128, y: u128) -> U256 {
107    let x_lo = x & LO_MASK;
108    let y_lo = y & LO_MASK;
109
110    let x_hi = x >> 64;
111    let y_hi = y >> 64;
112
113    let x_hi_y_lo = x_hi.wrapping_mul(y_lo);
114    let y_hi_x_lo = y_hi.wrapping_mul(x_lo);
115
116    let x_hi_y_lo_shifted = x_hi_y_lo << 64;
117    let y_hi_x_lo_shifted = y_hi_x_lo << 64;
118
119    let (lo_sum_1, carry_bool_1) = x_hi_y_lo_shifted.overflowing_add(y_hi_x_lo_shifted);
120    let (lo_sum_2, carry_bool_2) = lo_sum_1.overflowing_add(x_lo.wrapping_mul(y_lo));
121    let carry = carry_bool_1 as u128 + carry_bool_2 as u128;
122
123    U256(
124        x_hi.wrapping_mul(y_hi)
125            .wrapping_add(x_hi_y_lo_shifted >> 64)
126            .wrapping_add(y_hi_x_lo_shifted >> 64)
127            .wrapping_add(carry),
128        lo_sum_2,
129    )
130}
131
132#[inline]
133const fn add(x: U256, y: U256) -> U256 {
134    // NOTE: x and y are guaranteed to be <=
135    // (2^127 - 2)^2 = 2^254 - 4 * 2^127 + 4,
136    // so I think we don't have to worry about carries out of here.
137    let (low, carry) = x.1.overflowing_add(y.1);
138    let high = x.0 + y.0 + carry as u128;
139    U256(high, low)
140}
141
142// algorithm as described by Dresdenboy in "Fast calculations
143// modulo small mersenne primes like M61" at
144// https://www.mersenneforum.org/showthread.php?t=1955
145// I tried using a handwritten version of this that avoided the U256s,
146// but it was like half as fast somehow.
147#[inline]
148const fn mod_p_round_1(n: U256) -> U256 {
149    let low_bits = n.1 & P; // 127 bits of input
150    let intermediate_bits = (n.0 << 1) | (n.1 >> 127); // 128 of the 129 additional bits
151    let high_bit = n.0 >> 127;
152    let (sum, carry_bool) = low_bits.overflowing_add(intermediate_bits);
153    U256(carry_bool as u128 + high_bit, sum)
154}
155
156#[inline]
157const fn mod_p_round_2(n: U256) -> u128 {
158    let low_bits = n.1 & P; // 127 bits of input
159    let intermediate_bits = (n.0 << 1) | (n.1 >> 127); // 128 of the 129 additional bits
160    low_bits + intermediate_bits
161}
162
163#[inline]
164const fn mod_p_round_3(n: u128) -> u128 {
165    let low_bits = n & P; // 127 bits of input
166    let intermediate_bit = n >> 127; // 128 of the 129 additional bits
167    low_bits + intermediate_bit
168}
169
170#[inline]
171const fn constmod_p(n: U256) -> u128 {
172    let n1 = mod_p_round_1(n);
173    let n2 = mod_p_round_2(n1);
174    let n3 = mod_p_round_3(n2);
175
176    ((n3 + 1) & P).saturating_sub(1)
177}
178
179#[inline]
180fn mod_p(mut n: U256) -> u128 {
181    // n is at most 255 bits wide
182    if n.0 != 0 {
183        n = mod_p_round_1(n);
184    }
185    // n is at most 129 bits wide
186    let mut n_small = if n.0 != 0 || (n.1 > P) {
187        mod_p_round_2(n)
188    } else {
189        n.1
190    };
191    // n is at most 128 bits wide
192    if n_small > P {
193        n_small = mod_p_round_3(n_small);
194    }
195    // n is at most 127 bits wide
196
197    if n_small == P {
198        0
199    } else {
200        n_small
201    }
202}
203
204#[inline]
205pub fn matmul(a: HashMatrix, b: HashMatrix) -> HashMatrix {
206    HashMatrix(
207        mod_p(add(mul(a.0, b.0), mul(a.1, b.2))),
208        mod_p(add(mul(a.0, b.1), mul(a.1, b.3))),
209        mod_p(add(mul(a.2, b.0), mul(a.3, b.2))),
210        mod_p(add(mul(a.2, b.1), mul(a.3, b.3))),
211    )
212}
213
214/// Identical results to the `*` operator, but slower. Exposed to provide a
215/// `const` version in case you'd like to compile certain hashes into your
216/// binaries.
217#[must_use]
218#[inline]
219pub const fn constmatmul(a: HashMatrix, b: HashMatrix) -> HashMatrix {
220    HashMatrix(
221        constmod_p(add(mul(a.0, b.0), mul(a.1, b.2))),
222        constmod_p(add(mul(a.0, b.1), mul(a.1, b.3))),
223        constmod_p(add(mul(a.2, b.0), mul(a.3, b.2))),
224        constmod_p(add(mul(a.2, b.1), mul(a.3, b.3))),
225    )
226}
227
228fn hex_bytes_to_u128(hex_bytes: &[u8]) -> Result<u128, String> {
229    let mut hex_bytes = hex_bytes.iter().copied();
230    let mut result = [0u8; 16];
231    for byte in result.iter_mut() {
232        let digit1 = hex_digit_to_u8(hex_bytes.next().unwrap())?;
233        let digit2 = hex_digit_to_u8(hex_bytes.next().unwrap())?;
234        *byte = (digit1 << 4) | digit2;
235    }
236    Ok(u128::from_be_bytes(result))
237}
238
/// Convert one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value in `0..=15`.
///
/// # Errors
///
/// Returns an error naming the offending character if `hex_digit` is not a hex
/// digit. Previously the message showed the raw byte value (e.g. `103`)
/// instead of the character (`'g'`).
fn hex_digit_to_u8(hex_digit: u8) -> Result<u8, String> {
    let c = char::from(hex_digit);
    // `to_digit(16)` accepts exactly the ASCII hex digits, in either case.
    match c.to_digit(16) {
        // `value` is < 16, so the narrowing cast is lossless.
        Some(value) => Ok(value as u8),
        None => Err(format!("invalid hex character: {:?}", c)),
    }
}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    use alloc::vec::Vec;

    #[test]
    fn it_works() {
        assert_eq!(mul(1 << 127, 2), U256(1, 0));
        assert_eq!(
            mul(1 << 127, 1 << 127),
            U256(85070591730234615865843651857942052864, 0)
        );
        assert_eq!(mul(4, 4), U256(0, 16));
        assert_eq!(
            mul((1 << 127) + 4, (1 << 127) + 4),
            U256(85070591730234615865843651857942052868, 16)
        );
        // (2^128 - 1)^2 = (2^128 - 2) * 2^128 + 1 exercises every carry path.
        assert_eq!(mul(u128::MAX, u128::MAX), U256(u128::MAX - 1, 1));

        assert_eq!(mod_p(U256(0, P)), 0);
        assert_eq!(mod_p(U256(0, P + 1)), 1);
        assert_eq!(mod_p(U256(0, 0)), 0);
        assert_eq!(mod_p(U256(0, 1)), 1);
        assert_eq!(mod_p(U256(0, P - 1)), P - 1);
        assert_eq!(mod_p(U256(0, 1 << 127)), 1);
        assert_eq!(mod_p(U256(1, P)), 2);
        assert_eq!(mod_p(U256(1, 0)), 2);
        assert_eq!(mod_p(U256(P, 0)), 0);
        assert_eq!(mod_p(U256(P, P)), 0);
        assert_eq!(mod_p(U256(0, u128::MAX)), 1);
        // 2^256 - 1 is congruent to 4 - 1 = 3 (mod 2^127 - 1).
        assert_eq!(mod_p(U256(u128::MAX, u128::MAX)), 3);

        assert_eq!(
            HashMatrix(1, 0, 0, 1) * HashMatrix(1, 0, 0, 1),
            HashMatrix(1, 0, 0, 1)
        );
        assert_eq!(
            HashMatrix(2, 0, 0, 2) * HashMatrix(2, 0, 0, 2),
            HashMatrix(4, 0, 0, 4)
        );
        assert_eq!(
            HashMatrix(0, 1, 1, 0) * HashMatrix(2, 0, 0, 2),
            HashMatrix(0, 2, 2, 0)
        );
        // The swap matrix squares to the identity.
        assert_eq!(
            HashMatrix(0, 1, 1, 0) * HashMatrix(0, 1, 1, 0),
            HashMatrix(1, 0, 0, 1)
        );
        assert_eq!(
            HashMatrix(1, 0, 0, 1) * HashMatrix(P, 0, 0, P),
            HashMatrix(0, 0, 0, 0)
        );
        assert_eq!(
            HashMatrix(1, 0, 0, 1) * HashMatrix(P + 1, P + 5, 2, P),
            HashMatrix(1, 5, 2, 0)
        );
        assert_eq!(
            HashMatrix(P + 1, P + 3, P + 4, P + 5) * HashMatrix(P + 1, P, P, P + 1),
            HashMatrix(1, 3, 4, 5)
        );
    }

    #[test]
    fn test_hex_encoding_and_decoding() {
        let hash = HashMatrix(0, 0, 0, 0);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);

        let hash = HashMatrix(0, 0, 0, 1);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);

        let hash = HashMatrix(0, 0, 0, 31);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);

        let hash = HashMatrix(0, 0, 0, 89);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);

        let hash = HashMatrix(0, 0, 0, 1 << 34);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);

        let hash = HashMatrix(0, 1 << 31, 0, 1 << 34);
        assert_eq!(HashMatrix::from_hex(&hash.to_hex()).unwrap(), hash);
    }

    #[test]
    fn from_hex_rejects_malformed_input() {
        // Wrong length.
        assert!(HashMatrix::from_hex("").is_err());
        assert!(HashMatrix::from_hex(&"0".repeat(127)).is_err());
        assert!(HashMatrix::from_hex(&"0".repeat(129)).is_err());
        // Right length, but not hex digits.
        assert!(HashMatrix::from_hex(&"g".repeat(128)).is_err());
        // Right byte length, but not ASCII.
        assert!(HashMatrix::from_hex(&"é".repeat(64)).is_err());
    }

    use quickcheck::*;

    quickcheck! {
        fn composition(a: Vec<u8>, b: Vec<u8>) -> bool {
            let mut a = a;
            let mut b = b;
            let h1 = hash(&a) * hash(&b);
            a.append(&mut b);
            hash(&a) == h1
        }
    }

    quickcheck! {
        fn hex_encoding_and_decoding(bytes: Vec<u8>) -> bool {
            let hash = hash(&bytes);
            HashMatrix::from_hex(&hash.to_hex()).unwrap() == hash
        }
    }

    quickcheck! {
        fn mul_check(a: u128, b: u128) -> bool {
            use num_bigint::*;
            let res = mul(a, b);

            a.to_biguint().unwrap() * b.to_biguint().unwrap()
                == res.to_biguint().unwrap()
        }
    }

    quickcheck! {
        fn add_check(a: u128, b: u128, c: u128, d: u128) -> bool {
            let res = add(mul(a, b), mul(c, d));

            let big_res = a.to_biguint().unwrap() * b.to_biguint().unwrap()
                + c.to_biguint().unwrap() * d.to_biguint().unwrap();

            res.to_biguint().unwrap() == big_res
        }
    }

    quickcheck! {
        fn mod_p_check(a: u128, b: u128, c: u128, d: u128) -> bool {
            let res = mod_p(add(mul(a, b), mul(c, d)));

            let big_res = (a.to_biguint().unwrap() * b.to_biguint().unwrap()
                + c.to_biguint().unwrap() * d.to_biguint().unwrap())
                % P.to_biguint().unwrap();

            res.to_biguint().unwrap() == big_res
        }
    }

    quickcheck! {
        fn const_equiv(a: Vec<u8>, b: Vec<u8>) -> bool {
            let (x, y) = (hash(&a), hash(&b));
            constmatmul(x, y) == x * y
        }
    }

    quickcheck! {
        fn collision_search(a: Vec<u8>, b: Vec<u8>) -> bool {
            let ares = hash(&a);
            let bres = hash(&b);
            ares != bres || a == b
        }
    }

    #[cfg(feature = "std")]
    quickcheck! {
        fn par_equiv(a: Vec<u8>) -> bool {
            let h0 = hash(&a);
            let h1 = hash_par(&a);
            h0 == h1
        }
    }
}