light_utils/
lib.rs

1use std::{
2    env,
3    io::{self, prelude::*},
4    process::{Command, Stdio},
5    thread::spawn,
6};
7
8use ark_ff::PrimeField;
9use num_bigint::BigUint;
10
11use solana_program::keccak::hashv;
12use thiserror::Error;
13
14pub mod bigint;
15pub mod fee;
16pub mod offset;
17pub mod prime;
18pub mod rand;
19
20#[derive(Debug, Error, PartialEq)]
21pub enum UtilsError {
22    #[error("Invalid input size, expected at most {0}")]
23    InputTooLarge(usize),
24    #[error("Invalid chunk size")]
25    InvalidChunkSize,
26    #[error("Invalid seeds")]
27    InvalidSeeds,
28    #[error("Invalid rollover thresold")]
29    InvalidRolloverThreshold,
30}
31
32// NOTE(vadorovsky): Unfortunately, we need to do it by hand.
33// `num_derive::ToPrimitive` doesn't support data-carrying enums.
34impl From<UtilsError> for u32 {
35    fn from(e: UtilsError) -> u32 {
36        match e {
37            UtilsError::InputTooLarge(_) => 12001,
38            UtilsError::InvalidChunkSize => 12002,
39            UtilsError::InvalidSeeds => 12003,
40            UtilsError::InvalidRolloverThreshold => 12004,
41        }
42    }
43}
44
45impl From<UtilsError> for solana_program::program_error::ProgramError {
46    fn from(e: UtilsError) -> Self {
47        solana_program::program_error::ProgramError::Custom(e.into())
48    }
49}
50
51pub fn is_smaller_than_bn254_field_size_be(bytes: &[u8; 32]) -> bool {
52    let bigint = BigUint::from_bytes_be(bytes);
53    bigint < ark_bn254::Fr::MODULUS.into()
54}
55
56pub fn hash_to_bn254_field_size_be(bytes: &[u8]) -> Option<([u8; 32], u8)> {
57    let mut bump_seed = [u8::MAX];
58    // Loops with decreasing bump seed to find a valid hash which is less than
59    // bn254 Fr modulo field size.
60    for _ in 0..u8::MAX {
61        {
62            let mut hashed_value: [u8; 32] = hashv(&[bytes, bump_seed.as_ref()]).to_bytes();
63            // Truncates to 31 bytes so that value is less than bn254 Fr modulo
64            // field size.
65            hashed_value[0] = 0;
66            if is_smaller_than_bn254_field_size_be(&hashed_value) {
67                return Some((hashed_value, bump_seed[0]));
68            }
69        }
70        bump_seed[0] -= 1;
71    }
72    None
73}
74
75/// Hashes the provided `bytes` with Keccak256 and ensures the result fits
76/// in the BN254 prime field by repeatedly hashing the inputs with various
77/// "bump seeds" and truncating the resulting hash to 31 bytes.
78///
79/// The attempted "bump seeds" are bytes from 255 to 0.
80///
81/// # Examples
82///
83/// ```
84/// use light_utils::hashv_to_bn254_field_size_be;
85///
86/// hashv_to_bn254_field_size_be(&[b"foo", b"bar"]);
87/// ```
88pub fn hashv_to_bn254_field_size_be(bytes: &[&[u8]]) -> [u8; 32] {
89    let mut hashed_value: [u8; 32] = hashv(bytes).to_bytes();
90    // Truncates to 31 bytes so that value is less than bn254 Fr modulo
91    // field size.
92    hashed_value[0] = 0;
93    hashed_value
94}
95
96/// Applies `rustfmt` on the given string containing Rust code. The purpose of
97/// this function is to be able to format autogenerated code (e.g. with `quote`
98/// macro).
99pub fn rustfmt(code: String) -> Result<Vec<u8>, anyhow::Error> {
100    let mut cmd = match env::var_os("RUSTFMT") {
101        Some(r) => Command::new(r),
102        None => Command::new("rustfmt"),
103    };
104
105    let mut cmd = cmd
106        .stdin(Stdio::piped())
107        .stdout(Stdio::piped())
108        .stderr(Stdio::piped())
109        .spawn()?;
110
111    let mut stdin = cmd.stdin.take().unwrap();
112    let mut stdout = cmd.stdout.take().unwrap();
113
114    let stdin_handle = spawn(move || {
115        stdin.write_all(code.as_bytes()).unwrap();
116    });
117
118    let mut formatted_code = vec![];
119    io::copy(&mut stdout, &mut formatted_code)?;
120
121    let _ = cmd.wait();
122    stdin_handle.join().unwrap();
123
124    Ok(formatted_code)
125}
126
#[cfg(test)]
mod tests {
    use num_bigint::ToBigUint;
    use solana_program::pubkey::Pubkey;

    use crate::bigint::bigint_to_be_bytes_array;

    use super::*;

    #[test]
    fn test_utils_error_codes() {
        assert_eq!(u32::from(UtilsError::InputTooLarge(1)), 12001);
        assert_eq!(u32::from(UtilsError::InvalidChunkSize), 12002);
        assert_eq!(u32::from(UtilsError::InvalidSeeds), 12003);
        assert_eq!(u32::from(UtilsError::InvalidRolloverThreshold), 12004);
    }

    #[test]
    fn test_is_smaller_than_bn254_field_size_be() {
        let modulus: BigUint = ark_bn254::Fr::MODULUS.into();
        let modulus_bytes: [u8; 32] = bigint_to_be_bytes_array(&modulus).unwrap();
        assert!(!is_smaller_than_bn254_field_size_be(&modulus_bytes));

        let bigint = modulus.clone() - 1.to_biguint().unwrap();
        let bigint_bytes: [u8; 32] = bigint_to_be_bytes_array(&bigint).unwrap();
        assert!(is_smaller_than_bn254_field_size_be(&bigint_bytes));

        let bigint = modulus + 1.to_biguint().unwrap();
        let bigint_bytes: [u8; 32] = bigint_to_be_bytes_array(&bigint).unwrap();
        assert!(!is_smaller_than_bn254_field_size_be(&bigint_bytes));
    }

    #[test]
    fn test_hash_to_bn254_field_size_be() {
        for _ in 0..10_000 {
            let input_bytes = Pubkey::new_unique().to_bytes(); // Sample input
            let (hashed_value, bump) = hash_to_bn254_field_size_be(input_bytes.as_slice())
                .expect("Failed to find a hash within BN254 field size");
            assert_eq!(bump, 255, "Bump seed should be 255");
            assert!(
                is_smaller_than_bn254_field_size_be(&hashed_value),
                "Hashed value should be within BN254 field size"
            );
        }

        let max_input = [u8::MAX; 32];
        let (hashed_value, bump) = hash_to_bn254_field_size_be(max_input.as_slice())
            .expect("Failed to find a hash within BN254 field size");
        assert_eq!(bump, 255, "Bump seed should be 255");
        assert!(
            is_smaller_than_bn254_field_size_be(&hashed_value),
            "Hashed value should be within BN254 field size"
        );
    }

    #[test]
    fn test_hashv_to_bn254_field_size_be() {
        for _ in 0..10_000 {
            let input_bytes = [Pubkey::new_unique().to_bytes(); 4];
            let input_bytes = input_bytes.iter().map(|x| x.as_slice()).collect::<Vec<_>>();
            let hashed_value = hashv_to_bn254_field_size_be(input_bytes.as_slice());
            assert_eq!(hashed_value[0], 0, "Most significant byte should be zeroed");
            assert!(
                is_smaller_than_bn254_field_size_be(&hashed_value),
                "Hashed value should be within BN254 field size"
            );
        }

        let max_input = [[u8::MAX; 32]; 16];
        let max_input = max_input.iter().map(|x| x.as_slice()).collect::<Vec<_>>();
        let hashed_value = hashv_to_bn254_field_size_be(max_input.as_slice());
        assert_eq!(hashed_value[0], 0, "Most significant byte should be zeroed");
        assert!(
            is_smaller_than_bn254_field_size_be(&hashed_value),
            "Hashed value should be within BN254 field size"
        );
    }

    #[test]
    fn test_rustfmt() {
        let unformatted_code = "use std::mem;

fn main() {        println!(\"{}\", mem::size_of::<u64>()); }
        "
        .to_string();
        let formatted_code = rustfmt(unformatted_code).unwrap();
        assert_eq!(
            String::from_utf8_lossy(&formatted_code),
            "use std::mem;

fn main() {
    println!(\"{}\", mem::size_of::<u64>());
}
"
        );
    }
}