1use std::{
2 env,
3 io::{self, prelude::*},
4 process::{Command, Stdio},
5 thread::spawn,
6};
7
8use ark_ff::PrimeField;
9use num_bigint::BigUint;
10
11use solana_program::keccak::hashv;
12use thiserror::Error;
13
14pub mod bigint;
15pub mod fee;
16pub mod offset;
17pub mod prime;
18pub mod rand;
19
20#[derive(Debug, Error, PartialEq)]
21pub enum UtilsError {
22 #[error("Invalid input size, expected at most {0}")]
23 InputTooLarge(usize),
24 #[error("Invalid chunk size")]
25 InvalidChunkSize,
26 #[error("Invalid seeds")]
27 InvalidSeeds,
28 #[error("Invalid rollover thresold")]
29 InvalidRolloverThreshold,
30}
31
32impl From<UtilsError> for u32 {
35 fn from(e: UtilsError) -> u32 {
36 match e {
37 UtilsError::InputTooLarge(_) => 12001,
38 UtilsError::InvalidChunkSize => 12002,
39 UtilsError::InvalidSeeds => 12003,
40 UtilsError::InvalidRolloverThreshold => 12004,
41 }
42 }
43}
44
45impl From<UtilsError> for solana_program::program_error::ProgramError {
46 fn from(e: UtilsError) -> Self {
47 solana_program::program_error::ProgramError::Custom(e.into())
48 }
49}
50
51pub fn is_smaller_than_bn254_field_size_be(bytes: &[u8; 32]) -> bool {
52 let bigint = BigUint::from_bytes_be(bytes);
53 bigint < ark_bn254::Fr::MODULUS.into()
54}
55
56pub fn hash_to_bn254_field_size_be(bytes: &[u8]) -> Option<([u8; 32], u8)> {
57 let mut bump_seed = [u8::MAX];
58 for _ in 0..u8::MAX {
61 {
62 let mut hashed_value: [u8; 32] = hashv(&[bytes, bump_seed.as_ref()]).to_bytes();
63 hashed_value[0] = 0;
66 if is_smaller_than_bn254_field_size_be(&hashed_value) {
67 return Some((hashed_value, bump_seed[0]));
68 }
69 }
70 bump_seed[0] -= 1;
71 }
72 None
73}
74
75pub fn hashv_to_bn254_field_size_be(bytes: &[&[u8]]) -> [u8; 32] {
89 let mut hashed_value: [u8; 32] = hashv(bytes).to_bytes();
90 hashed_value[0] = 0;
93 hashed_value
94}
95
96pub fn rustfmt(code: String) -> Result<Vec<u8>, anyhow::Error> {
100 let mut cmd = match env::var_os("RUSTFMT") {
101 Some(r) => Command::new(r),
102 None => Command::new("rustfmt"),
103 };
104
105 let mut cmd = cmd
106 .stdin(Stdio::piped())
107 .stdout(Stdio::piped())
108 .stderr(Stdio::piped())
109 .spawn()?;
110
111 let mut stdin = cmd.stdin.take().unwrap();
112 let mut stdout = cmd.stdout.take().unwrap();
113
114 let stdin_handle = spawn(move || {
115 stdin.write_all(code.as_bytes()).unwrap();
116 });
117
118 let mut formatted_code = vec![];
119 io::copy(&mut stdout, &mut formatted_code)?;
120
121 let _ = cmd.wait();
122 stdin_handle.join().unwrap();
123
124 Ok(formatted_code)
125}
126
#[cfg(test)]
mod tests {
    use num_bigint::ToBigUint;
    use solana_program::pubkey::Pubkey;

    use crate::bigint::bigint_to_be_bytes_array;

    use super::*;

    #[test]
    fn test_is_smaller_than_bn254_field_size_be() {
        let modulus: BigUint = ark_bn254::Fr::MODULUS.into();
        let modulus_bytes: [u8; 32] = bigint_to_be_bytes_array(&modulus).unwrap();
        assert!(!is_smaller_than_bn254_field_size_be(&modulus_bytes));

        let bigint = modulus.clone() - 1.to_biguint().unwrap();
        let bigint_bytes: [u8; 32] = bigint_to_be_bytes_array(&bigint).unwrap();
        assert!(is_smaller_than_bn254_field_size_be(&bigint_bytes));

        let bigint = modulus + 1.to_biguint().unwrap();
        let bigint_bytes: [u8; 32] = bigint_to_be_bytes_array(&bigint).unwrap();
        assert!(!is_smaller_than_bn254_field_size_be(&bigint_bytes));
    }

    #[test]
    fn test_hash_to_bn254_field_size_be() {
        for _ in 0..10_000 {
            let input_bytes = Pubkey::new_unique().to_bytes();
            let (hashed_value, bump) = hash_to_bn254_field_size_be(input_bytes.as_slice())
                .expect("Failed to find a hash within BN254 field size");
            assert_eq!(bump, 255, "Bump seed should be 255");
            assert!(
                is_smaller_than_bn254_field_size_be(&hashed_value),
                "Hashed value should be within BN254 field size"
            );
            // The single-input helper must agree with hashing `input || [bump]`
            // through the multi-slice helper.
            assert_eq!(
                hashed_value,
                hashv_to_bn254_field_size_be(&[input_bytes.as_slice(), &[bump][..]]),
                "hash_to and hashv_to should agree for the same bump seed"
            );
        }

        let max_input = [u8::MAX; 32];
        let (hashed_value, bump) = hash_to_bn254_field_size_be(max_input.as_slice())
            .expect("Failed to find a hash within BN254 field size");
        assert_eq!(bump, 255, "Bump seed should be 255");
        assert!(
            is_smaller_than_bn254_field_size_be(&hashed_value),
            "Hashed value should be within BN254 field size"
        );
    }

    #[test]
    fn test_hashv_to_bn254_field_size_be() {
        for _ in 0..10_000 {
            let input_bytes = [Pubkey::new_unique().to_bytes(); 4];
            let input_bytes = input_bytes.iter().map(|x| x.as_slice()).collect::<Vec<_>>();
            let hashed_value = hashv_to_bn254_field_size_be(input_bytes.as_slice());
            assert!(
                is_smaller_than_bn254_field_size_be(&hashed_value),
                "Hashed value should be within BN254 field size"
            );
        }

        let max_input = [[u8::MAX; 32]; 16];
        let max_input = max_input.iter().map(|x| x.as_slice()).collect::<Vec<_>>();
        let hashed_value = hashv_to_bn254_field_size_be(max_input.as_slice());
        assert!(
            is_smaller_than_bn254_field_size_be(&hashed_value),
            "Hashed value should be within BN254 field size"
        );
    }

    #[test]
    fn test_rustfmt() {
        // Explicit `\n` escapes keep the fixture independent of how this
        // file itself is indented.
        let unformatted_code =
            "use std::mem;\n\nfn main() { println!(\"{}\", mem::size_of::<u64>()); }\n".to_string();
        let formatted_code = rustfmt(unformatted_code).unwrap();
        assert_eq!(
            String::from_utf8_lossy(&formatted_code),
            "use std::mem;\n\nfn main() {\n    println!(\"{}\", mem::size_of::<u64>());\n}\n"
        );
    }
}