1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
use std::{
    env,
    io::{self, prelude::*},
    process::{Command, Stdio},
    thread::spawn,
};

use ark_ff::PrimeField;
use num_bigint::BigUint;

use solana_program::keccak::hashv;
use thiserror::Error;

pub mod bigint;
pub mod fee;
pub mod offset;
pub mod prime;
pub mod rand;

#[derive(Debug, Error, PartialEq)]
pub enum UtilsError {
    #[error("Invalid input size, expected at most {0}")]
    InputTooLarge(usize),
    #[error("Invalid chunk size")]
    InvalidChunkSize,
    #[error("Invalid seeds")]
    InvalidSeeds,
    #[error("Invalid rollover thresold")]
    InvalidRolloverThreshold,
}

// NOTE(vadorovsky): Unfortunately, we need to do it by hand.
// `num_derive::ToPrimitive` doesn't support data-carrying enums.
/// Maps each [`UtilsError`] variant to its numeric error code (12001-12004).
impl From<UtilsError> for u32 {
    fn from(e: UtilsError) -> u32 {
        // Every code lives in the 12000 range; the offset picks the variant.
        const BASE: u32 = 12000;
        let offset: u32 = match e {
            UtilsError::InputTooLarge(_) => 1,
            UtilsError::InvalidChunkSize => 2,
            UtilsError::InvalidSeeds => 3,
            UtilsError::InvalidRolloverThreshold => 4,
        };
        BASE + offset
    }
}

/// Wraps a [`UtilsError`] as `ProgramError::Custom`, carrying the numeric
/// code produced by `u32::from`.
impl From<UtilsError> for solana_program::program_error::ProgramError {
    fn from(e: UtilsError) -> Self {
        let code = u32::from(e);
        Self::Custom(code)
    }
}

/// Returns `true` when `bytes`, read as a big-endian unsigned integer, is
/// strictly less than the BN254 scalar field (`Fr`) modulus.
pub fn is_smaller_than_bn254_field_size_be(bytes: &[u8; 32]) -> bool {
    let modulus = BigUint::from(ark_bn254::Fr::MODULUS);
    let value = BigUint::from_bytes_be(bytes);
    value < modulus
}

/// Hashes `bytes` together with a one-byte bump seed and returns a 32-byte
/// big-endian value below the BN254 `Fr` modulus, along with the bump seed
/// that produced it.
///
/// Candidates are tried with bump seeds from `255` down to `0` (inclusive).
/// The most significant byte of each hash is zeroed, which keeps the value
/// below 2^248 and therefore below the BN254 `Fr` modulus. In practice the
/// first candidate (bump seed `255`) is always accepted. The remaining
/// iterations are a defensive fallback.
///
/// Returns `None` only if no bump seed yields a value below the modulus.
pub fn hash_to_bn254_field_size_be(bytes: &[u8]) -> Option<([u8; 32], u8)> {
    // `(0..=u8::MAX).rev()` covers every seed 255..=0; the previous
    // `0..u8::MAX` counter ran only 255 times and never tried seed 0.
    for bump_seed in (0..=u8::MAX).rev() {
        let seed = [bump_seed];
        let mut hashed_value: [u8; 32] = hashv(&[bytes, seed.as_slice()]).to_bytes();
        // Truncate to 31 significant bytes so the value is less than the
        // bn254 Fr modulus.
        hashed_value[0] = 0;
        if is_smaller_than_bn254_field_size_be(&hashed_value) {
            return Some((hashed_value, bump_seed));
        }
    }
    None
}

/// Applies `rustfmt` on the given string containing Rust code. The purpose of
/// this function is to be able to format autogenerated code (e.g. with `quote`
/// macro).
pub fn rustfmt(code: String) -> Result<Vec<u8>, anyhow::Error> {
    let mut cmd = match env::var_os("RUSTFMT") {
        Some(r) => Command::new(r),
        None => Command::new("rustfmt"),
    };

    let mut cmd = cmd
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;

    let mut stdin = cmd.stdin.take().unwrap();
    let mut stdout = cmd.stdout.take().unwrap();

    let stdin_handle = spawn(move || {
        stdin.write_all(code.as_bytes()).unwrap();
    });

    let mut formatted_code = vec![];
    io::copy(&mut stdout, &mut formatted_code)?;

    let _ = cmd.wait();
    stdin_handle.join().unwrap();

    Ok(formatted_code)
}

#[cfg(test)]
mod tests {
    use num_bigint::ToBigUint;
    use solana_program::pubkey::Pubkey;

    use crate::bigint::bigint_to_be_bytes_array;

    use super::*;

    #[test]
    fn test_is_smaller_than_bn254_field_size_be() {
        let modulus: BigUint = ark_bn254::Fr::MODULUS.into();
        let modulus_bytes: [u8; 32] = bigint_to_be_bytes_array(&modulus).unwrap();
        assert!(!is_smaller_than_bn254_field_size_be(&modulus_bytes));

        let bigint = modulus.clone() - 1.to_biguint().unwrap();
        let bigint_bytes: [u8; 32] = bigint_to_be_bytes_array(&bigint).unwrap();
        assert!(is_smaller_than_bn254_field_size_be(&bigint_bytes));

        let bigint = modulus + 1.to_biguint().unwrap();
        let bigint_bytes: [u8; 32] = bigint_to_be_bytes_array(&bigint).unwrap();
        assert!(!is_smaller_than_bn254_field_size_be(&bigint_bytes));
    }

    #[test]
    fn test_hash_to_bn254_field_size_be() {
        for _ in 0..10_000 {
            let input_bytes = Pubkey::new_unique().to_bytes(); // Sample input
            let (hashed_value, bump) = hash_to_bn254_field_size_be(input_bytes.as_slice())
                .expect("Failed to find a hash within BN254 field size");
            // Zeroing the top byte always lands below the modulus, so the
            // first bump seed (255) is the one accepted.
            assert_eq!(bump, 255, "Bump seed should be 255");
            assert_eq!(
                hashed_value[0], 0,
                "Most significant byte should be zeroed"
            );
            assert!(
                is_smaller_than_bn254_field_size_be(&hashed_value),
                "Hashed value should be within BN254 field size"
            );
        }

        let max_input = [u8::MAX; 32];
        let (hashed_value, bump) = hash_to_bn254_field_size_be(max_input.as_slice())
            .expect("Failed to find a hash within BN254 field size");
        assert_eq!(bump, 255, "Bump seed should be 255");
        assert_eq!(
            hashed_value[0], 0,
            "Most significant byte should be zeroed"
        );
        assert!(
            is_smaller_than_bn254_field_size_be(&hashed_value),
            "Hashed value should be within BN254 field size"
        );
    }

    #[test]
    fn test_rustfmt() {
        let unformatted_code = "use std::mem;

fn main() {        println!(\"{}\", mem::size_of::<u64>()); }
        "
        .to_string();
        let formatted_code = rustfmt(unformatted_code).unwrap();
        assert_eq!(
            String::from_utf8_lossy(&formatted_code),
            "use std::mem;

fn main() {
    println!(\"{}\", mem::size_of::<u64>());
}
"
        );
    }
}