// indexed_merkle_tree/lib.rs

1extern crate alloc;
2
3pub mod error;
4pub mod node;
5pub mod tree;
6
7use alloc::string::ToString;
8use anyhow::{anyhow, Result};
9use borsh::{BorshDeserialize, BorshSerialize};
10use error::{MerkleTreeError, MerkleTreeResult};
11use num_bigint::BigUint;
12use num_traits::Num;
13use serde::{Deserialize, Serialize};
14use sha2::{Digest, Sha256};
15
16#[cfg(feature = "bls")]
17use bls12_381::Scalar;
18
19#[derive(
20    Serialize, Deserialize, BorshSerialize, BorshDeserialize, Debug, Clone, Copy, PartialEq, Eq,
21)]
22pub struct Hash([u8; 32]);
23
#[cfg(feature = "bls")]
impl TryFrom<Hash> for Scalar {
    type Error = anyhow::Error;

    /// Converts a hash into a BLS12-381 scalar by reading its 32 bytes as a big-endian
    /// 256-bit integer.
    ///
    /// [`Scalar::from_raw`] takes four `u64` limbs in little-endian limb order (least
    /// significant limb first). Limb `i` is therefore the big-endian `u64` stored in
    /// bytes `[24 - 8 * i, 32 - 8 * i)` of the hash. `from_raw` maps an integer that is
    /// not below the field modulus to its congruent scalar.
    ///
    /// # Errors
    /// None in practice: splitting a 32-byte array into 8-byte chunks cannot fail.
    /// The fallible signature is kept for compatibility with existing callers.
    fn try_from(value: Hash) -> Result<Scalar, Self::Error> {
        let mut limbs = [0u64; 4];
        // `rchunks_exact` walks from the end of the array, i.e. from the least significant
        // big-endian word to the most significant one, which is exactly `from_raw`'s
        // limb order.
        for (limb, word) in limbs.iter_mut().zip(value.0.rchunks_exact(8)) {
            *limb = u64::from_be_bytes(
                word.try_into()
                    .expect("rchunks_exact(8) only yields 8-byte chunks"),
            );
        }
        Ok(Scalar::from_raw(limbs))
    }
}
54
55impl Hash {
56    pub const fn new(bytes: [u8; 32]) -> Self {
57        Hash(bytes)
58    }
59
60    pub fn from_hex(hex_str: &str) -> MerkleTreeResult<Self> {
61        let mut bytes = [0u8; 32];
62        hex::decode_to_slice(hex_str, &mut bytes)
63            .map_err(|e| anyhow!(MerkleTreeError::InvalidFormatError(e.to_string())))?;
64        Ok(Hash(bytes))
65    }
66
67    #[cfg(feature = "std")]
68    pub fn to_hex(&self) -> String {
69        hex::encode(self.0)
70    }
71
72    #[cfg(not(feature = "std"))]
73    pub fn to_hex(&self) -> [u8; 64] {
74        // This is correct, as 32 bytes become 64 hex characters
75        let mut hex = [0u8; 64];
76        hex::encode_to_slice(self.0, &mut hex)
77            .expect("The output is exactly twice the size of the input");
78        hex
79    }
80}
81
82impl AsRef<[u8]> for Hash {
83    fn as_ref(&self) -> &[u8] {
84        &self.0
85    }
86}
87
/// Displays the hash as its hexadecimal encoding, exactly as returned by `Hash::to_hex`.
#[cfg(feature = "std")]
impl std::fmt::Display for Hash {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let encoded = self.to_hex();
        f.write_str(&encoded)
    }
}
94
/// Modulus used by [`sha256_mod`], as a big-endian hexadecimal string without a `0x`
/// prefix.
///
/// This value is the order of the BLS12-381 scalar field (usually written `r`; the
/// modulus of `bls12_381::Scalar`). It is not the 381-bit base-field modulus `p`.
pub const MODULUS: &str = "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001";
96
97/// Computes the SHA256 hash of the given string.
98///
99/// This function takes a string reference as input and returns its SHA256 hash in hexadecimal format. We're using the `crypto-hash` crate to compute the hash.
100/// It is used to ensure data integrity and uniqueness within the Merkle Tree structure.
101///
102/// # Arguments
103/// * `input` - A reference to the string to be hashed.
104///
105/// # Returns
106/// A `String` representing the hexadecimal SHA256 hash of the input.
107/// TODO: Implement the `sha256` function that computes the SHA256 hash of the given string but with the `sha2` crate (should return a [u8; 32] bc we want to use that in the future everywhere instead of strings).
108pub fn sha256(input: &[u8]) -> Hash {
109    let mut hasher = Sha256::new();
110    hasher.update(input);
111    Hash(hasher.finalize().into())
112}
113
114/// Computes the SHA256 hash of the given string and reduces it modulo the BLS12-381 curve modulus.
115///
116/// This function takes a string reference as input, computes its SHA256 hash, and then reduces
117/// the hash modulo the BLS12-381 curve modulus to ensure it fits within its field Fp.
118/// The resulting value is returned in hexadecimal format.
119///
120/// # Arguments
121/// * `input` - A reference to the string to be hashed.
122///
123/// # Returns
124/// A `String` representing the hexadecimal SHA256 hash of the input reduced modulo the BLS12-381 curve modulus.
125
126pub fn sha256_mod(input: &[u8]) -> Hash {
127    let hash = sha256(input);
128    let hash_bigint = BigUint::from_bytes_be(hash.as_ref());
129    let modulus = BigUint::from_str_radix(MODULUS, 16).expect("Invalid modulus");
130    let modded_hash = hash_bigint % modulus;
131    let mut bytes = modded_hash.to_bytes_be();
132    if bytes.len() < 32 {
133        bytes = core::iter::repeat(0)
134            .take(32 - bytes.len())
135            .chain(bytes)
136            .collect();
137    }
138    Hash::new(bytes.try_into().unwrap())
139}