light_hasher/
poseidon.rs

1use thiserror::{self, Error};
2
3use crate::{
4    errors::HasherError,
5    zero_bytes::{poseidon::ZERO_BYTES, ZeroBytes},
6    zero_indexed_leaf::poseidon::ZERO_INDEXED_LEAF,
7    Hash, Hasher,
8};
9
10#[derive(Debug, Error, PartialEq)]
11pub enum PoseidonSyscallError {
12    #[error("Invalid parameters.")]
13    InvalidParameters,
14    #[error("Invalid endianness.")]
15    InvalidEndianness,
16    #[error("Invalid number of inputs. Maximum allowed is 12.")]
17    InvalidNumberOfInputs,
18    #[error("Input is an empty slice.")]
19    EmptyInput,
20    #[error(
21        "Invalid length of the input. The length matching the modulus of the prime field is 32."
22    )]
23    InvalidInputLength,
24    #[error("Failed to convert bytest into a prime field element.")]
25    BytesToPrimeFieldElement,
26    #[error("Input is larger than the modulus of the prime field.")]
27    InputLargerThanModulus,
28    #[error("Failed to convert a vector of bytes into an array.")]
29    VecToArray,
30    #[error("Failed to convert the number of inputs from u64 to u8.")]
31    U64Tou8,
32    #[error("Failed to convert bytes to BigInt")]
33    BytesToBigInt,
34    #[error("Invalid width. Choose a width between 2 and 16 for 1 to 15 inputs.")]
35    InvalidWidthCircom,
36    #[error("Unexpected error")]
37    Unexpected,
38}
39impl From<u64> for PoseidonSyscallError {
40    fn from(error: u64) -> Self {
41        match error {
42            1 => PoseidonSyscallError::InvalidParameters,
43            2 => PoseidonSyscallError::InvalidEndianness,
44            3 => PoseidonSyscallError::InvalidNumberOfInputs,
45            4 => PoseidonSyscallError::EmptyInput,
46            5 => PoseidonSyscallError::InvalidInputLength,
47            6 => PoseidonSyscallError::BytesToPrimeFieldElement,
48            7 => PoseidonSyscallError::InputLargerThanModulus,
49            8 => PoseidonSyscallError::VecToArray,
50            9 => PoseidonSyscallError::U64Tou8,
51            10 => PoseidonSyscallError::BytesToBigInt,
52            11 => PoseidonSyscallError::InvalidWidthCircom,
53            _ => PoseidonSyscallError::Unexpected,
54        }
55    }
56}
57
58impl From<PoseidonSyscallError> for u64 {
59    fn from(error: PoseidonSyscallError) -> Self {
60        match error {
61            PoseidonSyscallError::InvalidParameters => 1,
62            PoseidonSyscallError::InvalidEndianness => 2,
63            PoseidonSyscallError::InvalidNumberOfInputs => 3,
64            PoseidonSyscallError::EmptyInput => 4,
65            PoseidonSyscallError::InvalidInputLength => 5,
66            PoseidonSyscallError::BytesToPrimeFieldElement => 6,
67            PoseidonSyscallError::InputLargerThanModulus => 7,
68            PoseidonSyscallError::VecToArray => 8,
69            PoseidonSyscallError::U64Tou8 => 9,
70            PoseidonSyscallError::BytesToBigInt => 10,
71            PoseidonSyscallError::InvalidWidthCircom => 11,
72            PoseidonSyscallError::Unexpected => 12,
73        }
74    }
75}
76
/// Zero-sized handle for the Poseidon hash; the hashing itself is exposed
/// through this type's `Hasher` implementation.
#[derive(Clone, Copy, Debug)]
pub struct Poseidon;
79
80impl Hasher for Poseidon {
81    const ID: u8 = 0;
82
83    fn hash(val: &[u8]) -> Result<Hash, HasherError> {
84        Self::hashv(&[val])
85    }
86
87    fn hashv(vals: &[&[u8]]) -> Result<Hash, HasherError> {
88        // Perform the calculation inline, calling this from within a program is
89        // not supported.
90        #[cfg(not(target_os = "solana"))]
91        {
92            use ark_bn254::Fr;
93            use light_poseidon::{Poseidon, PoseidonBytesHasher};
94
95            let mut hasher = Poseidon::<Fr>::new_circom(vals.len())?;
96            let res = hasher.hash_bytes_be(vals)?;
97
98            Ok(res)
99        }
100        // Call via a system call to perform the calculation.
101        #[cfg(target_os = "solana")]
102        {
103            use crate::HASH_BYTES;
104            // TODO: reenable once LightHasher refactor is merged
105            // solana_program::msg!("remove len check onchain.");
106            // for val in vals {
107            //     if val.len() != 32 {
108            //         return Err(HasherError::InvalidInputLength(val.len()));
109            //     }
110            // }
111            let mut hash_result = [0; HASH_BYTES];
112            let result = unsafe {
113                crate::syscalls::sol_poseidon(
114                    0, // bn254
115                    0, // big-endian
116                    vals as *const _ as *const u8,
117                    vals.len() as u64,
118                    &mut hash_result as *mut _ as *mut u8,
119                )
120            };
121
122            match result {
123                0 => Ok(hash_result),
124                e => Err(HasherError::from(PoseidonSyscallError::from(e))),
125            }
126        }
127    }
128
129    fn zero_bytes() -> ZeroBytes {
130        ZERO_BYTES
131    }
132
133    fn zero_indexed_leaf() -> [u8; 32] {
134        ZERO_INDEXED_LEAF
135    }
136}