1#![cfg_attr(docsrs, feature(doc_cfg))]
/// C-layout parameter block handed to the `sol_big_mod_exp` syscall by
/// [`big_mod_exp`] on the `atlas` target.
///
/// Each operand is described by a pointer to its first byte plus its length in
/// bytes. [`big_mod_exp`] fills these from big-endian byte slices. The struct
/// only borrows that memory: the pointed-to buffers must stay alive for as long
/// as the struct is in use.
///
/// Field order and types are part of the C ABI (`#[repr(C)]`) and must not be
/// changed.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct BigModExpParams {
    /// Pointer to the big-endian bytes of the base.
    pub base: *const u8,
    /// Length of the base in bytes.
    pub base_len: u64,
    /// Pointer to the big-endian bytes of the exponent.
    pub exponent: *const u8,
    /// Length of the exponent in bytes.
    pub exponent_len: u64,
    /// Pointer to the big-endian bytes of the modulus.
    pub modulus: *const u8,
    /// Length of the modulus in bytes. [`big_mod_exp`] also uses this as the
    /// width of its result buffer.
    pub modulus_len: u64,
}
11
12pub fn big_mod_exp(base: &[u8], exponent: &[u8], modulus: &[u8]) -> Vec<u8> {
14 #[cfg(not(target_os = "atlas"))]
15 {
16 use {
17 num_bigint::BigUint,
18 num_traits::{One, Zero},
19 };
20
21 let modulus_len = modulus.len();
22 let base = BigUint::from_bytes_be(base);
23 let exponent = BigUint::from_bytes_be(exponent);
24 let modulus = BigUint::from_bytes_be(modulus);
25
26 if modulus.is_zero() || modulus.is_one() {
27 return vec![0_u8; modulus_len];
28 }
29
30 let ret_int = base.modpow(&exponent, &modulus);
31 let ret_int = ret_int.to_bytes_be();
32 let mut return_value = vec![0_u8; modulus_len.saturating_sub(ret_int.len())];
33 return_value.extend(ret_int);
34 return_value
35 }
36
37 #[cfg(target_os = "atlas")]
38 {
39 let mut return_value = vec![0_u8; modulus.len()];
40
41 let param = BigModExpParams {
42 base: base as *const _ as *const u8,
43 base_len: base.len() as u64,
44 exponent: exponent as *const _ as *const u8,
45 exponent_len: exponent.len() as u64,
46 modulus: modulus as *const _ as *const u8,
47 modulus_len: modulus.len() as u64,
48 };
49 unsafe {
50 atlas_define_syscall::definitions::sol_big_mod_exp(
51 ¶m as *const _ as *const u8,
52 return_value.as_mut_slice() as *mut _ as *mut u8,
53 )
54 };
55
56 return_value
57 }
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds every vector in `tests/data/big_mod_exp_cases.json` through
    /// [`big_mod_exp`] and requires a byte-for-byte match with the expected
    /// output.
    #[test]
    fn big_mod_exp_test() {
        /// One JSON test vector. Keys are PascalCase and every value is a hex
        /// string.
        #[derive(serde_derive::Deserialize)]
        #[serde(rename_all = "PascalCase")]
        struct TestCase {
            base: String,
            exponent: String,
            modulus: String,
            expected: String,
        }

        let raw_json = include_str!("../tests/data/big_mod_exp_cases.json");
        let vectors: Vec<TestCase> = serde_json::from_str(raw_json).unwrap();

        for vector in &vectors {
            // Decode all four operands up front, in field order.
            let base_bytes = array_bytes::hex2bytes_unchecked(&vector.base);
            let exponent_bytes = array_bytes::hex2bytes_unchecked(&vector.exponent);
            let modulus_bytes = array_bytes::hex2bytes_unchecked(&vector.modulus);
            let wanted = array_bytes::hex2bytes_unchecked(&vector.expected);

            let got = big_mod_exp(&base_bytes, &exponent_bytes, &modulus_bytes);
            assert_eq!(got, wanted);
        }
    }
}