atlas_big_mod_exp/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
/// Parameter block for the `sol_big_mod_exp` syscall used by [`big_mod_exp`].
///
/// Each operand is described by a raw pointer plus a byte length and holds an
/// unsigned big-endian integer. `#[repr(C)]` fixes the field order and layout
/// so the struct can be handed across the syscall boundary as raw bytes.
///
/// The pointers are borrowed, not owned: whoever builds this struct must keep
/// each pointer valid for reads of its corresponding `*_len` bytes for as long
/// as the struct is in use.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct BigModExpParams {
    /// Pointer to the big-endian bytes of the base.
    pub base: *const u8,
    /// Length of the base, in bytes.
    pub base_len: u64,
    /// Pointer to the big-endian bytes of the exponent.
    pub exponent: *const u8,
    /// Length of the exponent, in bytes.
    pub exponent_len: u64,
    /// Pointer to the big-endian bytes of the modulus.
    pub modulus: *const u8,
    /// Length of the modulus, in bytes.
    pub modulus_len: u64,
}
11
12/// Big integer modular exponentiation
13pub fn big_mod_exp(base: &[u8], exponent: &[u8], modulus: &[u8]) -> Vec<u8> {
14    #[cfg(not(target_os = "atlas"))]
15    {
16        use {
17            num_bigint::BigUint,
18            num_traits::{One, Zero},
19        };
20
21        let modulus_len = modulus.len();
22        let base = BigUint::from_bytes_be(base);
23        let exponent = BigUint::from_bytes_be(exponent);
24        let modulus = BigUint::from_bytes_be(modulus);
25
26        if modulus.is_zero() || modulus.is_one() {
27            return vec![0_u8; modulus_len];
28        }
29
30        let ret_int = base.modpow(&exponent, &modulus);
31        let ret_int = ret_int.to_bytes_be();
32        let mut return_value = vec![0_u8; modulus_len.saturating_sub(ret_int.len())];
33        return_value.extend(ret_int);
34        return_value
35    }
36
37    #[cfg(target_os = "atlas")]
38    {
39        let mut return_value = vec![0_u8; modulus.len()];
40
41        let param = BigModExpParams {
42            base: base as *const _ as *const u8,
43            base_len: base.len() as u64,
44            exponent: exponent as *const _ as *const u8,
45            exponent_len: exponent.len() as u64,
46            modulus: modulus as *const _ as *const u8,
47            modulus_len: modulus.len() as u64,
48        };
49        unsafe {
50            atlas_define_syscall::definitions::sol_big_mod_exp(
51                &param as *const _ as *const u8,
52                return_value.as_mut_slice() as *mut _ as *mut u8,
53            )
54        };
55
56        return_value
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs every case from the JSON fixture, reporting which case failed.
    #[test]
    fn big_mod_exp_test() {
        #[derive(serde_derive::Deserialize)]
        #[serde(rename_all = "PascalCase")]
        struct TestCase {
            base: String,
            exponent: String,
            modulus: String,
            expected: String,
        }

        let test_data = include_str!("../tests/data/big_mod_exp_cases.json");

        let test_cases: Vec<TestCase> = serde_json::from_str(test_data)
            .expect("big_mod_exp_cases.json should deserialize into test cases");
        // Guard against a truncated fixture making this test pass vacuously.
        assert!(!test_cases.is_empty(), "big_mod_exp_cases.json has no cases");

        for (index, test) in test_cases.iter().enumerate() {
            let base = array_bytes::hex2bytes_unchecked(&test.base);
            let exponent = array_bytes::hex2bytes_unchecked(&test.exponent);
            let modulus = array_bytes::hex2bytes_unchecked(&test.modulus);
            let expected = array_bytes::hex2bytes_unchecked(&test.expected);
            let result = big_mod_exp(&base, &exponent, &modulus);
            assert_eq!(
                result, expected,
                "case #{index} failed: base={}, exponent={}, modulus={}",
                test.base, test.exponent, test.modulus
            );
        }
    }

    /// The result is always exactly `modulus.len()` bytes wide.
    #[test]
    fn big_mod_exp_pads_result_to_modulus_width() {
        // 2^3 mod 5 = 3, left-padded to the two-byte modulus width.
        assert_eq!(big_mod_exp(&[2], &[3], &[0, 5]), vec![0_u8, 3]);
        // 2^10 mod 0x0403 (1027) = 1024 = 0x0400, which already fills the width.
        assert_eq!(big_mod_exp(&[2], &[10], &[4, 3]), vec![4_u8, 0]);
    }

    /// Zero and one moduli short-circuit to an all-zero result of the same width.
    #[test]
    fn big_mod_exp_zero_and_one_modulus_yield_zeros() {
        assert_eq!(big_mod_exp(&[7], &[1], &[]), Vec::<u8>::new());
        assert_eq!(big_mod_exp(&[7], &[1], &[0, 0]), vec![0_u8, 0]);
        assert_eq!(big_mod_exp(&[7], &[1], &[0, 1]), vec![0_u8, 0]);
    }

    /// An empty base is zero, and a zero result is still padded.
    #[test]
    fn big_mod_exp_zero_base() {
        // 0^5 mod 7 = 0, padded to the two-byte modulus width.
        assert_eq!(big_mod_exp(&[], &[5], &[0, 7]), vec![0_u8, 0]);
    }
}