use std::ops::ShrAssign;

use mbedtls::bignum::Mpi;
use mbedtls::error::Error;
use mbedtls::rng::RngCallback;
use mbedtls_sys_auto as mbedtls_sys;
pub use mbedtls_sys_auto::mpi_sint;

use crate::tee_api_defines::{
    TEE_ERROR_NOT_SUPPORTED, TEE_ERROR_OVERFLOW, TEE_ERROR_SHORT_BUFFER, TEE_SUCCESS,
};
use crate::tee_api_types::{TEE_BigInt, TEE_BigIntFMM, TEE_BigIntFMMContext, TEE_Result};
/// Header stored at the start of every `TEE_BigInt` buffer handled by this module.
///
/// `#[repr(C)]` fixes the layout at 8 bytes (`BIGINT_HDR_SIZE_IN_U32` 32-bit words). The
/// magnitude follows immediately after as 32-bit words, the lowest-order word first.
#[repr(C)]
struct BigintHdr {
    /// Sign copied to/from the mbedtls `mpi.s` field; a negative value marks a negative number.
    pub sign: i32,
    /// Capacity of the data area that follows the header, in 32-bit words.
    pub alloc_size: u16,
    /// Number of significant 32-bit words currently stored (0 for the value zero).
    pub nblimbs: u16,
}
/// Size of the bigint header in 32-bit words: 4 bytes of sign, 2 of `alloc_size` and 2 of
/// `nblimbs`.
pub const BIGINT_HDR_SIZE_IN_U32: usize = (4 + 2 + 2) / 4;
/// Largest bignum width, in bits, supported by this module. `TEE_BigIntInit` rejects
/// buffers longer than a quarter of this value in 32-bit words.
const CFG_TA_BIGNUM_MAX_BITS: usize = 4 * 1024;
/// Conversion between mbedtls `Mpi` values and the `TEE_BigInt` memory layout used in this
/// module: a `BigintHdr` followed by `alloc_size` 32-bit data words.
pub trait TeeBigIntExt {
    /// Writes `self` into the bigint at `bigint`, whose data area holds `alloc_size` words.
    ///
    /// # Errors
    /// Returns an error if `bigint` is null or the value needs more than `alloc_size` words
    /// (see the `Mpi` implementation for the exact error codes).
    ///
    /// # Safety
    /// `bigint` must be null or valid for writes of the header plus `alloc_size` `u32` words.
    unsafe fn to_teebigint(&self, bigint: *mut TEE_BigInt, alloc_size: usize) -> Result<(), Error>;
    /// Decodes the bigint at `bigint` into a new value. In the `Mpi` implementation a null
    /// pointer yields zero.
    ///
    /// # Safety
    /// `bigint` must be null or point to an initialised bigint whose header and stored data
    /// words are readable.
    unsafe fn from_teebigint(bigint: *const TEE_BigInt) -> Result<Self, Error>
    where
        Self: Sized;
}
impl TeeBigIntExt for Mpi {
unsafe fn to_teebigint(&self, bigint: *mut TEE_BigInt, alloc_size: usize) -> Result<(), Error> {
if bigint.is_null() {
return Err(mbedtls::error::codes::MpiBadInputData.into());
}
let mpi_limbs_count = {
let handle: *const mbedtls_sys::mpi = self.into();
unsafe {
let mut limbs_count = (*handle).n;
while limbs_count > 0 && self.get_limb(limbs_count - 1) == 0 {
limbs_count -= 1;
}
limbs_count
}
};
let mut tee_limbs_count = 0;
for i in 0..mpi_limbs_count {
let limb = self.get_limb(i);
if i == mpi_limbs_count - 1 {
if limb == 0 {
continue; } else if limb <= 0xFFFFFFFF {
tee_limbs_count += 1;
} else {
tee_limbs_count += 2;
}
} else {
tee_limbs_count += 2;
}
}
if alloc_size < tee_limbs_count {
return Err(mbedtls::error::codes::MpiBufferTooSmall.into());
}
unsafe {
let header = bigint as *mut BigintHdr;
let handle: *const mbedtls_sys::mpi = self.into();
(*header).sign = (*handle).s; (*header).alloc_size = alloc_size as u16; (*header).nblimbs = tee_limbs_count as u16;
let mut tee_index = 0;
for i in 0..mpi_limbs_count {
let limb = self.get_limb(i);
let low_ptr = bigint.add(2 + tee_index) as *mut u32;
*low_ptr = (limb & 0xFFFFFFFF) as u32;
tee_index += 1;
if i < mpi_limbs_count - 1 || limb > 0xFFFFFFFF {
let high_ptr = bigint.add(2 + tee_index) as *mut u32;
*high_ptr = ((limb >> 32) & 0xFFFFFFFF) as u32;
tee_index += 1;
}
}
}
Ok(())
}
unsafe fn from_teebigint(bigint: *const TEE_BigInt) -> Result<Self, Error> {
if bigint.is_null() {
return Mpi::new(0);
}
let (sign, nblimbs) = unsafe {
let header = bigint as *const BigintHdr;
((*header).sign, (*header).nblimbs as usize)
};
if nblimbs == 0 {
return Mpi::new(0);
}
let mut data_vec = Vec::with_capacity((nblimbs + 1) / 2);
let mut i = 0;
while i < nblimbs {
let low = unsafe {
let low_ptr = bigint.add(2 + i) as *const u32;
if low_ptr.is_null() { 0 } else { *low_ptr }
};
if i + 1 < nblimbs {
let high = unsafe {
let high_ptr = bigint.add(2 + i + 1) as *const u32;
if high_ptr.is_null() { 0 } else { *high_ptr }
};
let combined = ((high as u64) << 32) | (low as u64);
data_vec.push(combined as mbedtls_sys::mpi_uint);
i += 2;
} else {
data_vec.push(low as mbedtls_sys::mpi_uint);
i += 1;
}
}
let trimmed_data = {
let mut len = data_vec.len();
while len > 0 && data_vec[len - 1] == 0 {
len -= 1;
}
&data_vec[..len]
};
if trimmed_data.is_empty() {
return Mpi::new(0);
}
let mut mpi = Mpi::new(0)?;
unsafe {
let handle: *mut mbedtls_sys::mpi = (&mut mpi).into();
let result = mbedtls_sys::mpi_grow(handle, trimmed_data.len());
if result != 0 {
return Err(mbedtls::error::codes::MpiBadInputData.into());
}
(*handle).s = sign;
let dst_ptr = (*handle).p;
std::ptr::copy_nonoverlapping(trimmed_data.as_ptr(), dst_ptr, trimmed_data.len());
}
Ok(mpi)
}
}
/// Raw limb access for `Mpi`, used by the `TEE_BigInt` conversions in this module.
trait MpiExt {
    /// Returns limb `n` of the value, or 0 when `n` is not below the mpi's allocated limb
    /// count.
    fn get_limb(&self, n: usize) -> mbedtls_sys::mpi_uint;
}
impl MpiExt for Mpi {
    /// Reads limb `n` directly from the mbedtls limb array, returning 0 when `n` is out of
    /// range instead of reading past the allocation.
    fn get_limb(&self, n: usize) -> mbedtls_sys::mpi_uint {
        let handle: *const mbedtls_sys::mpi = self.into();
        // SAFETY: `handle` refers to the mpi owned by `self`. Its `p` array holds `n` (the
        // field) limbs, and the index is checked against that count before reading.
        unsafe {
            let allocated = (*handle).n;
            if n >= allocated {
                return 0;
            }
            *(*handle).p.add(n)
        }
    }
}
/// Initialises a `TEE_BigInt` that occupies `len` 32-bit words (GP `TEE_BigIntInit`).
///
/// Zeroes the whole buffer, then writes a header describing the value zero: sign +1, no
/// significant words, and `alloc_size = len - BIGINT_HDR_SIZE_IN_U32` data words.
///
/// # Panics
/// * "Too large bigint" if `len` exceeds `CFG_TA_BIGNUM_MAX_BITS / 4` words.
/// * "Null bigint" if `big_int` is null.
/// * "Too small bigint" if `len` cannot hold the header.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntInit(big_int: *mut TEE_BigInt, len: usize) {
    if len > CFG_TA_BIGNUM_MAX_BITS / 4 {
        panic!("Too large bigint");
    }
    if big_int.is_null() {
        panic!("Null bigint");
    }
    // Without this guard, `len - BIGINT_HDR_SIZE_IN_U32` underflows and the header writes
    // below land outside a buffer shorter than the header.
    if len < BIGINT_HDR_SIZE_IN_U32 {
        panic!("Too small bigint");
    }
    // Fits in u16 because len <= CFG_TA_BIGNUM_MAX_BITS / 4 (1024).
    let alloc_size = (len - BIGINT_HDR_SIZE_IN_U32) as u16;
    // SAFETY: the caller provides `len` writable u32 words at `big_int`, and `len` covers the
    // header.
    unsafe {
        core::ptr::write_bytes(big_int as *mut u8, 0, len * core::mem::size_of::<u32>());
        let hdr = big_int as *mut BigintHdr;
        (*hdr).sign = 1;
        (*hdr).alloc_size = alloc_size;
        (*hdr).nblimbs = 0;
    }
}
/// Loads a big-endian magnitude from `buffer` into `dest`
/// (GP `TEE_BigIntConvertFromOctetString`). The value is made negative when `sign < 0`.
///
/// A zero magnitude always stays non-negative. An empty input (`buffer_len == 0`, in which
/// case `buffer` may be null) yields zero.
///
/// Returns `TEE_SUCCESS`, or `TEE_ERROR_OVERFLOW` if any of the following holds:
/// * `dest` is null;
/// * `buffer` is null with a non-zero length;
/// * the bytes cannot be parsed;
/// * the value does not fit in `dest`.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntConvertFromOctetString(
    dest: *mut TEE_BigInt,
    buffer: *const u8,
    buffer_len: usize,
    sign: i32,
) -> TEE_Result {
    if dest.is_null() || (buffer.is_null() && buffer_len != 0) {
        return TEE_ERROR_OVERFLOW;
    }
    // `slice::from_raw_parts` requires a non-null pointer even for length 0, so the empty
    // case never touches `buffer`.
    let bytes: &[u8] = if buffer_len == 0 {
        &[]
    } else {
        // SAFETY: `buffer` is non-null and the caller guarantees `buffer_len` readable bytes.
        unsafe { core::slice::from_raw_parts(buffer, buffer_len) }
    };
    let Ok(mut mpi) = Mpi::from_binary(bytes) else {
        return TEE_ERROR_OVERFLOW;
    };
    if sign < 0 && bytes.iter().any(|&b| b != 0) {
        let handle: *mut mbedtls_sys::mpi = (&mut mpi).into();
        // SAFETY: `handle` points at the mpi owned by `mpi`. The value parsed from unsigned
        // bytes is non-negative, so setting the sign to -1 negates it.
        unsafe { (*handle).s = -1 };
    }
    // SAFETY: `dest` is non-null and must reference an initialised bigint.
    unsafe {
        let alloc_size = (*(dest as *const BigintHdr)).alloc_size as usize;
        match mpi.to_teebigint(dest, alloc_size) {
            Ok(()) => TEE_SUCCESS,
            Err(_) => TEE_ERROR_OVERFLOW,
        }
    }
}
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntConvertToOctetString(
buffer: *mut u8,
buffer_len: *mut usize,
big_int: *const TEE_BigInt,
) -> TEE_Result {
if buffer_len.is_null() || big_int.is_null() {
return TEE_ERROR_OVERFLOW; }
let mpi = match unsafe { Mpi::from_teebigint(big_int) } {
Ok(mpi) => mpi,
Err(_) => return TEE_ERROR_OVERFLOW,
};
let sz = match mpi.byte_length() {
Ok(len) => len,
Err(_) => return TEE_ERROR_OVERFLOW,
};
let provided_buffer_len = unsafe { *buffer_len };
if sz <= provided_buffer_len {
if !buffer.is_null() {
match mpi.to_binary() {
Ok(binary_data) => {
unsafe {
core::ptr::copy_nonoverlapping(
binary_data.as_ptr(),
buffer,
binary_data.len(),
);
}
}
Err(_) => return TEE_ERROR_OVERFLOW,
}
}
} else {
unsafe { *buffer_len = sz };
return TEE_ERROR_OVERFLOW; }
unsafe { *buffer_len = sz };
TEE_SUCCESS
}
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntConvertFromS32(dest: *mut TEE_BigInt, short_val: i32) {
unsafe {
let hdr = dest as *mut BigintHdr;
if short_val < 0 {
(*hdr).sign = -1;
} else {
(*hdr).sign = 1;
}
let abs_val = if short_val < 0 {
-(short_val as i64) as u32
} else {
short_val as u32
};
let data_ptr = dest.add(2) as *mut u32;
let alloc_size = (*hdr).alloc_size as usize;
for i in 0..alloc_size {
*data_ptr.add(i) = 0;
}
if alloc_size > 0 {
*data_ptr = abs_val;
(*hdr).nblimbs = if abs_val == 0 { 0 } else { 1 };
} else {
(*hdr).nblimbs = 0;
}
}
}
/// Converts a bigint to an `i32` (GP `TEE_BigIntConvertToS32`).
///
/// If the value lies in `i32::MIN..=i32::MAX`, writes `*dest` and returns `TEE_SUCCESS`.
/// Otherwise, or if `src` cannot be decoded, returns `TEE_ERROR_OVERFLOW` and leaves `*dest`
/// untouched.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntConvertToS32(dest: *mut i32, src: *const TEE_BigInt) -> TEE_Result {
    // SAFETY: `src` must reference an initialised bigint (GP API contract).
    let decoded = unsafe { Mpi::from_teebigint(src) };
    let Ok(value) = decoded else {
        return TEE_ERROR_OVERFLOW;
    };
    let Ok(magnitude_bytes) = value.to_binary() else {
        return TEE_ERROR_OVERFLOW;
    };
    // More than four magnitude bytes can never fit in 32 bits.
    if magnitude_bytes.len() > 4 {
        return TEE_ERROR_OVERFLOW;
    }
    // Fold the big-endian bytes into an i64, which has room for |i32::MIN| = 2^31.
    let magnitude = magnitude_bytes
        .iter()
        .fold(0i64, |acc, &byte| (acc << 8) | i64::from(byte));
    let signed = if value.sign() == mbedtls::bignum::Sign::Positive {
        magnitude
    } else {
        -magnitude
    };
    match i32::try_from(signed) {
        Ok(converted) => {
            // SAFETY: `dest` must be a valid, writable i32 pointer (GP API contract).
            unsafe { *dest = converted };
            TEE_SUCCESS
        }
        Err(_) => TEE_ERROR_OVERFLOW,
    }
}
/// Compares two bigints (GP `TEE_BigIntCmp`).
///
/// Returns -1, 0 or 1 as `op1` is less than, equal to, or greater than `op2`. Returns 0 if
/// either operand cannot be decoded.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntCmp(op1: *const TEE_BigInt, op2: *const TEE_BigInt) -> i32 {
    // SAFETY: both pointers must reference initialised bigints (GP API contract).
    let decoded = unsafe { (Mpi::from_teebigint(op1), Mpi::from_teebigint(op2)) };
    match decoded {
        // `Ordering` is `repr(i8)` with Less = -1, Equal = 0, Greater = 1.
        (Ok(lhs), Ok(rhs)) => lhs.cmp(&rhs) as i32,
        _ => 0,
    }
}
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntCmpS32(src: *const TEE_BigInt, short_val: i32) -> i32 {
let mpi = match unsafe { Mpi::from_teebigint(src) } {
Ok(mpi) => mpi,
Err(_) => return 0, };
let cmp_mpi = match Mpi::new(short_val as mpi_sint) {
Ok(mpi) => mpi,
Err(_) => return 0, };
match mpi.cmp(&cmp_mpi) {
std::cmp::Ordering::Less => -1,
std::cmp::Ordering::Equal => 0,
std::cmp::Ordering::Greater => 1,
}
}
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntShiftRight(dest: *mut TEE_BigInt, op: *const TEE_BigInt, bits: usize) {
let mut temp_mpi = match unsafe { Mpi::from_teebigint(op) } {
Ok(mpi) => mpi,
Err(_) => return, };
temp_mpi.shr_assign(bits);
let dest_info = unsafe {
let hdr = dest as *mut BigintHdr;
(*hdr).alloc_size as usize
};
unsafe {
match temp_mpi.to_teebigint(dest, dest_info) {
Ok(_) => {
}
Err(_) => {
let hdr = dest as *mut BigintHdr;
(*hdr).sign = 0;
(*hdr).nblimbs = 0;
let data_ptr = dest.add(2) as *mut u32;
for i in 0..dest_info {
*data_ptr.add(i) = 0;
}
}
}
}
}
/// Returns bit `bit_index` of the magnitude of `src` (GP `TEE_BigIntGetBit`). Returns
/// `false` if `src` cannot be decoded.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntGetBit(src: *const TEE_BigInt, bit_index: u32) -> bool {
    // SAFETY: `src` must reference an initialised bigint (GP API contract).
    let decoded = unsafe { Mpi::from_teebigint(src) };
    decoded.is_ok_and(|mpi| mpi.get_bit(bit_index as usize))
}
/// Returns the number of significant bits in `src` (GP `TEE_BigIntGetBitCount`). Returns 0
/// if `src` cannot be decoded or its length cannot be determined.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntGetBitCount(src: *const TEE_BigInt) -> u32 {
    // SAFETY: `src` must reference an initialised bigint (GP API contract).
    let decoded = unsafe { Mpi::from_teebigint(src) };
    decoded
        .ok()
        .and_then(|mpi| mpi.bit_length().ok())
        .map_or(0, |bits| bits as u32)
}
/// Sets or clears bit `bit_index` of `op` in place (GP `TEE_BigIntSetBit`).
///
/// Returns `TEE_ERROR_OVERFLOW` if `op` cannot be decoded, the bit cannot be set, or the
/// result no longer fits in `op`'s capacity. Otherwise returns `TEE_SUCCESS`.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntSetBit(op: *mut TEE_BigInt, bit_index: u32, value: bool) -> TEE_Result {
    // SAFETY: `op` must reference an initialised bigint (GP API contract).
    let decoded = unsafe { Mpi::from_teebigint(op as *const TEE_BigInt) };
    let Ok(mut mpi) = decoded else {
        return TEE_ERROR_OVERFLOW;
    };
    if mpi.set_bit(bit_index as usize, value).is_err() {
        return TEE_ERROR_OVERFLOW;
    }
    // SAFETY: same bigint as above, re-encoded with its own recorded capacity.
    let written = unsafe {
        let capacity = (*(op as *const BigintHdr)).alloc_size as usize;
        mpi.to_teebigint(op, capacity)
    };
    if written.is_ok() {
        TEE_SUCCESS
    } else {
        TEE_ERROR_OVERFLOW
    }
}
/// Copies `src` into `dest` word-for-word (GP `TEE_BigIntAssign`).
///
/// Self-assignment is a no-op that returns `TEE_SUCCESS`. Returns `TEE_ERROR_OVERFLOW` if
/// either pointer is null, or if `src` stores more words than `dest` can hold.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntAssign(dest: *mut TEE_BigInt, src: *const TEE_BigInt) -> TEE_Result {
    if dest as *const TEE_BigInt == src {
        return TEE_SUCCESS;
    }
    if dest.is_null() || src.is_null() {
        return TEE_ERROR_OVERFLOW;
    }
    // SAFETY: both pointers are non-null and distinct, and must reference initialised bigints.
    // The copy length is bounded by `dest`'s capacity.
    unsafe {
        let from = src as *const BigintHdr;
        let to = dest as *mut BigintHdr;
        let count = (*from).nblimbs;
        if count > (*to).alloc_size {
            return TEE_ERROR_OVERFLOW;
        }
        (*to).nblimbs = count;
        (*to).sign = (*from).sign;
        core::ptr::copy_nonoverlapping(
            (src as *const u32).add(BIGINT_HDR_SIZE_IN_U32),
            (dest as *mut u32).add(BIGINT_HDR_SIZE_IN_U32),
            usize::from(count),
        );
    }
    TEE_SUCCESS
}
/// Stores `|src|` in `dest` (GP `TEE_BigIntAbs`).
///
/// Copies with `TEE_BigIntAssign`, then forces the sign to +1. Any error from the copy is
/// returned unchanged.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntAbs(dest: *mut TEE_BigInt, src: *const TEE_BigInt) -> TEE_Result {
    let status = TEE_BigIntAssign(dest, src);
    if status != TEE_SUCCESS {
        return status;
    }
    // SAFETY: relies on the caller passing a valid `dest`. Assign only rejects a null `dest`
    // when it differs from `src`.
    unsafe { (*(dest as *mut BigintHdr)).sign = 1 };
    status
}
/// Computes `dest = func(op1, op2)` for an mbedtls binary operation such as `mpi_add_mpi`,
/// `mpi_sub_mpi`, `mpi_mul_mpi` or `mpi_mod_mpi`.
///
/// Both operands are decoded into independent `Mpi` copies before `func` runs, so any of
/// `dest`, `op1` and `op2` may alias one another. Passing mbedtls the same mpi as both output
/// and operand is not safe for every routine. For example, `mpi_mod_mpi` compares and
/// subtracts its output against the modulus, so aliasing those two can produce 0.
///
/// Returns `TEE_ERROR_OVERFLOW` if an operand cannot be decoded, `func` reports an error, or
/// the result does not fit in `dest`. In those cases `dest` is left unchanged. Otherwise
/// returns `TEE_SUCCESS`.
fn bigint_binary(
    dest: *mut TEE_BigInt,
    op1: *const TEE_BigInt,
    op2: *const TEE_BigInt,
    func: unsafe extern "C" fn(
        *mut mbedtls_sys_auto::mpi,
        *const mbedtls_sys_auto::mpi,
        *const mbedtls_sys_auto::mpi,
    ) -> i32,
) -> TEE_Result {
    // SAFETY: all pointers must reference initialised bigints (GP API contract). The result is
    // computed into a fresh Mpi that aliases neither operand.
    unsafe {
        let alloc_size = (*(dest as *const BigintHdr)).alloc_size as usize;
        let (Ok(lhs), Ok(rhs), Ok(mut result)) = (
            Mpi::from_teebigint(op1),
            Mpi::from_teebigint(op2),
            Mpi::new(0),
        ) else {
            return TEE_ERROR_OVERFLOW;
        };
        if func((&mut result).into(), (&lhs).into(), (&rhs).into()) != 0 {
            return TEE_ERROR_OVERFLOW;
        }
        match result.to_teebigint(dest, alloc_size) {
            Ok(()) => TEE_SUCCESS,
            Err(_) => TEE_ERROR_OVERFLOW,
        }
    }
}
fn bigint_binary_mod(
dest: *mut TEE_BigInt,
op1: *const TEE_BigInt,
op2: *const TEE_BigInt,
n: *const TEE_BigInt,
func: unsafe extern "C" fn(
*mut mbedtls_sys_auto::mpi,
*const mbedtls_sys_auto::mpi,
*const mbedtls_sys_auto::mpi,
) -> i32,
) -> TEE_Result {
unsafe {
if TEE_BigIntCmpS32(n, 2) < 0 {
panic!("Modulus is too short");
}
let dst_hdr = dest as *mut BigintHdr;
let alloc_size = (*dst_hdr).alloc_size as usize;
let mpi_n = match Mpi::from_teebigint(n) {
Ok(mpi) => mpi,
Err(_) => return TEE_ERROR_OVERFLOW,
};
let mpi_op1 = if op1 == dest as *const TEE_BigInt {
None } else {
match Mpi::from_teebigint(op1) {
Ok(mpi) => Some(mpi),
Err(_) => return TEE_ERROR_OVERFLOW,
}
};
let mpi_op2 = if op2 == dest as *const TEE_BigInt {
None } else if op2 == op1 {
mpi_op1.clone() } else {
match Mpi::from_teebigint(op2) {
Ok(mpi) => Some(mpi),
Err(_) => return TEE_ERROR_OVERFLOW,
}
};
let mut mpi_dest = match Mpi::from_teebigint(dest as *const TEE_BigInt) {
Ok(mpi) => mpi,
Err(_) => return TEE_ERROR_OVERFLOW,
};
let mut mpi_t = match Mpi::new(0) {
Ok(mpi) => mpi,
Err(_) => return TEE_ERROR_OVERFLOW,
};
let result = if op1 == dest as *const TEE_BigInt && op2 == dest as *const TEE_BigInt {
func((&mut mpi_t).into(), (&mpi_dest).into(), (&mpi_dest).into())
} else if op1 == dest as *const TEE_BigInt {
if op2 == op1 {
func((&mut mpi_t).into(), (&mpi_dest).into(), (&mpi_dest).into())
} else {
func(
(&mut mpi_t).into(),
(&mpi_dest).into(),
mpi_op2.as_ref().unwrap().into(),
)
}
} else if op2 == dest as *const TEE_BigInt {
func(
(&mut mpi_t).into(),
mpi_op1.as_ref().unwrap().into(),
(&mpi_dest).into(),
)
} else {
if op2 == op1 {
let op1_handle = mpi_op1.as_ref().unwrap().into();
func((&mut mpi_t).into(), op1_handle, op1_handle)
} else {
func(
(&mut mpi_t).into(),
mpi_op1.as_ref().unwrap().into(),
mpi_op2.as_ref().unwrap().into(),
)
}
};
if result != 0 {
return TEE_ERROR_OVERFLOW;
}
let mod_result =
mbedtls_sys::mpi_mod_mpi((&mut mpi_dest).into(), (&mpi_t).into(), (&mpi_n).into());
if mod_result != 0 {
return TEE_ERROR_OVERFLOW;
}
match mpi_dest.to_teebigint(dest, alloc_size) {
Ok(()) => TEE_SUCCESS,
Err(_) => TEE_ERROR_OVERFLOW,
}
}
}
/// Computes `dest = op1 + op2` (GP `TEE_BigIntAdd`).
///
/// This entry point returns nothing, so any failure reported by `bigint_binary` (such as the
/// sum not fitting in `dest`) is ignored.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntAdd(
    dest: *mut TEE_BigInt,
    op1: *const TEE_BigInt,
    op2: *const TEE_BigInt,
) {
    _ = bigint_binary(dest, op1, op2, mbedtls_sys_auto::mpi_add_mpi);
}
/// Computes `dest = op1 - op2` (GP `TEE_BigIntSub`).
///
/// This entry point returns nothing, so any failure reported by `bigint_binary` is ignored.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntSub(
    dest: *mut TEE_BigInt,
    op1: *const TEE_BigInt,
    op2: *const TEE_BigInt,
) {
    _ = bigint_binary(dest, op1, op2, mbedtls_sys_auto::mpi_sub_mpi);
}
/// Stores `-src` in `dest` (GP `TEE_BigIntNeg`).
///
/// `src` is fully decoded before `dest` is written, so the two may be the same bigint.
/// Decode and capacity failures are ignored.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntNeg(dest: *mut TEE_BigInt, src: *const TEE_BigInt) {
    // SAFETY: `dest` and `src` must reference initialised bigints (GP API contract).
    unsafe {
        let capacity = (*(dest as *const BigintHdr)).alloc_size as usize;
        let Ok(mut value) = Mpi::from_teebigint(src) else {
            return;
        };
        let handle: *mut mbedtls_sys::mpi = (&mut value).into();
        // Negate by flipping the sign field of our private copy.
        (*handle).s = -(*handle).s;
        let _ = value.to_teebigint(dest, capacity);
    }
}
/// Number of 32-bit words needed for a bigint of `n` bits, header included.
///
/// Uses `div_ceil`, so huge `n` (for example `usize::MAX` from a C caller of
/// `TEE_BigIntFMMSizeInU32`) cannot overflow the way `(n + 31) / 32` does.
fn tee_big_int_size_in_u32(n: usize) -> usize {
    n.div_ceil(32) + BIGINT_HDR_SIZE_IN_U32
}
/// Computes `dest = op1 * op2` (GP `TEE_BigIntMul`).
///
/// The product is formed in a scratch bigint sized for `bits(op1) + bits(op2)`, then copied
/// into `dest` by adding zero. This means `dest` may alias either operand.
///
/// Failures are ignored: a product too large for `dest` leaves `dest` as the final addition
/// left it. If the multiplication itself fails, the scratch value stays zero and `dest` ends
/// up zero.
///
/// # Panics
/// Panics via `TEE_BigIntInit` if the scratch size exceeds the supported maximum.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntMul(
    dest: *mut TEE_BigInt,
    op1: *const TEE_BigInt,
    op2: *const TEE_BigInt,
) {
    let bits1 = TEE_BigIntGetBitCount(op1) as usize;
    let bits2 = TEE_BigIntGetBitCount(op2) as usize;
    let scratch_len = tee_big_int_size_in_u32(bits1) + tee_big_int_size_in_u32(bits2);
    let mut scratch_storage = vec![0u32; scratch_len];
    let scratch = scratch_storage.as_mut_ptr();
    TEE_BigIntInit(scratch, scratch_len);
    let _ = bigint_binary(scratch, op1, op2, mbedtls_sys_auto::mpi_mul_mpi);

    // `TEE_BigIntInit` writes the zero operand's header, so it needs mutable backing storage.
    // Writing through a pointer derived from an immutable array is undefined behaviour.
    let mut zero_storage = [0u32; BIGINT_HDR_SIZE_IN_U32 + 1];
    let zero = zero_storage.as_mut_ptr();
    TEE_BigIntInit(zero, BIGINT_HDR_SIZE_IN_U32 + 1);
    TEE_BigIntAdd(dest, scratch, zero);
}
/// Computes `dest = op * op` (GP `TEE_BigIntSquare`) by multiplying `op` with itself.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntSquare(dest: *mut TEE_BigInt, op: *const TEE_BigInt) {
    let (lhs, rhs) = (op, op);
    TEE_BigIntMul(dest, lhs, rhs);
}
/// Truncating division (GP `TEE_BigIntDiv`): `dest_q = op1 / op2`, `dest_r = op1 % op2`.
///
/// Either destination may be null to skip it. Both operands are decoded before any
/// destination is written, so aliasing is harmless. mbedtls and capacity failures are
/// ignored.
///
/// # Panics
/// Panics with "Division by zero" if `op2` decodes to zero.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntDiv(
    dest_q: *mut TEE_BigInt,
    dest_r: *mut TEE_BigInt,
    op1: *const TEE_BigInt,
    op2: *const TEE_BigInt,
) {
    // SAFETY: op1/op2, and dest_q/dest_r when non-null, must reference initialised bigints.
    unsafe {
        if let Ok(divisor) = Mpi::from_teebigint(op2) {
            // A divisor whose magnitude is all zero bytes, or that cannot be serialised, counts
            // as zero.
            let divisor_is_zero = !divisor
                .to_binary()
                .is_ok_and(|bytes| bytes.iter().any(|&b| b != 0));
            if divisor_is_zero {
                panic!("Division by zero");
            }
        }
        let capacity_of = |p: *mut TEE_BigInt| -> Option<usize> {
            if p.is_null() {
                None
            } else {
                Some((*(p as *const BigintHdr)).alloc_size as usize)
            }
        };
        let q_capacity = capacity_of(dest_q);
        let r_capacity = capacity_of(dest_r);

        let Ok(dividend) = Mpi::from_teebigint(op1) else {
            return;
        };
        let divisor = if op2 == op1 {
            dividend.clone()
        } else {
            match Mpi::from_teebigint(op2) {
                Ok(mpi) => mpi,
                Err(_) => return,
            }
        };
        let (Ok(mut quotient), Ok(mut remainder)) = (Mpi::new(0), Mpi::new(0)) else {
            return;
        };
        let status = mbedtls_sys::mpi_div_mpi(
            (&mut quotient).into(),
            (&mut remainder).into(),
            (&dividend).into(),
            (&divisor).into(),
        );
        if status != 0 {
            return;
        }
        if let Some(capacity) = q_capacity {
            let _ = quotient.to_teebigint(dest_q, capacity);
        }
        if let Some(capacity) = r_capacity {
            let _ = remainder.to_teebigint(dest_r, capacity);
        }
    }
}
/// Computes `dest = op mod n` (GP `TEE_BigIntMod`) via `mpi_mod_mpi`. Failures are ignored.
///
/// # Panics
/// Panics with "Modulus is too short" if `n < 2`.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntMod(
    dest: *mut TEE_BigInt,
    op: *const TEE_BigInt,
    n: *const TEE_BigInt,
) {
    let modulus_too_small = TEE_BigIntCmpS32(n, 2) < 0;
    if modulus_too_small {
        panic!("Modulus is too short");
    }
    _ = bigint_binary(dest, op, n, mbedtls_sys_auto::mpi_mod_mpi);
}
/// Computes `dest = (op1 + op2) mod n` (GP `TEE_BigIntAddMod`). Failures are ignored.
///
/// # Panics
/// Panics (inside `bigint_binary_mod`) if `n < 2`.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntAddMod(
    dest: *mut TEE_BigInt,
    op1: *const TEE_BigInt,
    op2: *const TEE_BigInt,
    n: *const TEE_BigInt,
) {
    _ = bigint_binary_mod(dest, op1, op2, n, mbedtls_sys_auto::mpi_add_mpi);
}
/// Computes `dest = (op1 - op2) mod n` (GP `TEE_BigIntSubMod`). Failures are ignored.
///
/// # Panics
/// Panics (inside `bigint_binary_mod`) if `n < 2`.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntSubMod(
    dest: *mut TEE_BigInt,
    op1: *const TEE_BigInt,
    op2: *const TEE_BigInt,
    n: *const TEE_BigInt,
) {
    _ = bigint_binary_mod(dest, op1, op2, n, mbedtls_sys_auto::mpi_sub_mpi);
}
/// Computes `dest = (op1 * op2) mod n` (GP `TEE_BigIntMulMod`). Failures are ignored.
///
/// # Panics
/// Panics (inside `bigint_binary_mod`) if `n < 2`.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntMulMod(
    dest: *mut TEE_BigInt,
    op1: *const TEE_BigInt,
    op2: *const TEE_BigInt,
    n: *const TEE_BigInt,
) {
    _ = bigint_binary_mod(dest, op1, op2, n, mbedtls_sys_auto::mpi_mul_mpi);
}
/// Computes `dest = op^2 mod n` (GP `TEE_BigIntSquareMod`) as a modular multiplication of
/// `op` with itself.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntSquareMod(
    dest: *mut TEE_BigInt,
    op: *const TEE_BigInt,
    n: *const TEE_BigInt,
) {
    let (lhs, rhs) = (op, op);
    TEE_BigIntMulMod(dest, lhs, rhs, n);
}
/// Computes `dest = op^-1 mod n` (GP `TEE_BigIntInvMod`) with `mpi_inv_mod`.
///
/// If mbedtls reports an error (for example, no inverse exists) or a decode fails, `dest` is
/// left untouched.
///
/// # Panics
/// Panics if `n < 2` or `op == 0`.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntInvMod(
    dest: *mut TEE_BigInt,
    op: *const TEE_BigInt,
    n: *const TEE_BigInt,
) {
    if TEE_BigIntCmpS32(n, 2) < 0 || TEE_BigIntCmpS32(op, 0) == 0 {
        panic!("too small modulus or trying to invert zero");
    }
    // SAFETY: all pointers must reference initialised bigints (GP API contract). Operands are
    // decoded before `dest` is written, so `dest` may alias `op`.
    unsafe {
        let capacity = (*(dest as *const BigintHdr)).alloc_size as usize;
        let Ok(modulus) = Mpi::from_teebigint(n) else {
            return;
        };
        let Ok(value) = Mpi::from_teebigint(op) else {
            return;
        };
        let Ok(mut inverse) = Mpi::new(0) else {
            return;
        };
        let status =
            mbedtls_sys::mpi_inv_mod((&mut inverse).into(), (&value).into(), (&modulus).into());
        if status == 0 {
            let _ = inverse.to_teebigint(dest, capacity);
        }
    }
}
/// Returns `true` when bit 0 of `src`'s magnitude is set, i.e. the value is odd.
/// Unreadable bigints report `false`, following `TEE_BigIntGetBit`.
fn tee_bigint_is_odd(src: *const TEE_BigInt) -> bool {
    const LOWEST_BIT: u32 = 0;
    TEE_BigIntGetBit(src, LOWEST_BIT)
}
fn tee_bigint_is_even(src: *const TEE_BigInt) -> bool {
!tee_bigint_is_odd(src)
}
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntExpMod(
dest: *mut TEE_BigInt,
op1: *const TEE_BigInt,
op2: *const TEE_BigInt,
n: *const TEE_BigInt,
_context: *const TEE_BigIntFMMContext,
) -> TEE_Result {
if TEE_BigIntCmpS32(n, 2) <= 0 {
panic!("too small modulus");
}
if tee_bigint_is_even(n) {
return TEE_ERROR_OVERFLOW; }
unsafe {
let dst_hdr = dest as *mut BigintHdr;
let alloc_size = (*dst_hdr).alloc_size as usize;
let mpi_n = match Mpi::from_teebigint(n) {
Ok(mpi) => mpi,
Err(_) => return TEE_ERROR_OVERFLOW,
};
let mpi_op1 = if op1 == dest as *const TEE_BigInt {
match Mpi::from_teebigint(op1) {
Ok(mpi) => mpi,
Err(_) => return TEE_ERROR_OVERFLOW,
}
} else {
match Mpi::from_teebigint(op1) {
Ok(mpi) => mpi,
Err(_) => return TEE_ERROR_OVERFLOW,
}
};
let mpi_op2 = if op2 == dest as *const TEE_BigInt {
match Mpi::from_teebigint(op2) {
Ok(mpi) => mpi,
Err(_) => return TEE_ERROR_OVERFLOW,
}
} else if op2 == op1 {
mpi_op1.clone()
} else {
match Mpi::from_teebigint(op2) {
Ok(mpi) => mpi,
Err(_) => return TEE_ERROR_OVERFLOW,
}
};
let mut mpi_dest = match Mpi::new(0) {
Ok(mpi) => mpi,
Err(_) => return TEE_ERROR_OVERFLOW,
};
let result = mbedtls_sys::mpi_exp_mod(
(&mut mpi_dest).into(),
(&mpi_op1).into(),
(&mpi_op2).into(),
(&mpi_n).into(),
core::ptr::null_mut(), );
if result != 0 {
return TEE_ERROR_OVERFLOW;
}
match mpi_dest.to_teebigint(dest, alloc_size) {
Ok(()) => TEE_SUCCESS,
Err(_) => TEE_ERROR_OVERFLOW,
}
}
}
/// Returns `true` when `gcd(op1, op2) == 1` (GP `TEE_BigIntRelativePrime`).
///
/// Any decode, allocation or mbedtls failure yields `false`.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntRelativePrime(op1: *const TEE_BigInt, op2: *const TEE_BigInt) -> bool {
    // SAFETY: both pointers must reference initialised bigints (GP API contract).
    unsafe {
        let Ok(lhs) = Mpi::from_teebigint(op1) else {
            return false;
        };
        let rhs = if op2 == op1 {
            lhs.clone()
        } else {
            match Mpi::from_teebigint(op2) {
                Ok(mpi) => mpi,
                Err(_) => return false,
            }
        };
        let Ok(mut divisor) = Mpi::new(0) else {
            return false;
        };
        if mbedtls_sys::mpi_gcd((&mut divisor).into(), (&lhs).into(), (&rhs).into()) != 0 {
            return false;
        }
        Mpi::new(1).is_ok_and(|one| divisor.cmp(&one) == std::cmp::Ordering::Equal)
    }
}
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntComputeExtendedGcd(
gcd: *mut TEE_BigInt,
u: *mut TEE_BigInt,
v: *mut TEE_BigInt,
op1: *const TEE_BigInt,
op2: *const TEE_BigInt,
) {
if gcd.is_null() || op1.is_null() || op2.is_null() {
return;
}
unsafe {
let mpi_op1 = match Mpi::from_teebigint(op1) {
Ok(mpi) => mpi,
Err(_) => return,
};
let mpi_op2 = if op2 == op1 {
mpi_op1.clone()
} else {
match Mpi::from_teebigint(op2) {
Ok(mpi) => mpi,
Err(_) => return,
}
};
if u.is_null() && v.is_null() {
let mut mpi_gcd = match Mpi::new(0) {
Ok(mpi) => mpi,
Err(_) => return,
};
let result =
mbedtls_sys::mpi_gcd((&mut mpi_gcd).into(), (&mpi_op1).into(), (&mpi_op2).into());
if result != 0 {
return;
}
let hdr = gcd as *mut BigintHdr;
let alloc_size = (*hdr).alloc_size as usize;
let _ = mpi_gcd.to_teebigint(gcd, alloc_size);
return;
}
let s1 = mpi_op1.sign();
let s2 = mpi_op2.sign();
let abs_op1 = mpi_abs_value(&mpi_op1);
let abs_op2 = mpi_abs_value(&mpi_op2);
let cmp = abs_op1.cmp(&abs_op2);
let (mpi_gcd, mpi_u, mpi_v) = match cmp {
std::cmp::Ordering::Equal => {
let gcd_result = abs_op1;
let u_result = match Mpi::new(1) {
Ok(mpi) => mpi,
Err(_) => return,
};
let v_result = match Mpi::new(0) {
Ok(mpi) => mpi,
Err(_) => return,
};
(gcd_result, u_result, v_result)
}
std::cmp::Ordering::Greater => extended_gcd_algorithm(&abs_op1, &abs_op2),
std::cmp::Ordering::Less => {
let (gcd_result, v_result, u_result) = extended_gcd_algorithm(&abs_op2, &abs_op1);
(gcd_result, u_result, v_result)
}
};
let final_mpi_u = if s1 == mbedtls::bignum::Sign::Negative {
negate_mpi_safe(&mpi_u)
} else {
mpi_u
};
let final_mpi_v = if s2 == mbedtls::bignum::Sign::Negative {
negate_mpi_safe(&mpi_v)
} else {
mpi_v
};
if !u.is_null() {
let hdr = u as *mut BigintHdr;
let alloc_size = (*hdr).alloc_size as usize;
let _ = final_mpi_u.to_teebigint(u, alloc_size);
}
if !v.is_null() {
let hdr = v as *mut BigintHdr;
let alloc_size = (*hdr).alloc_size as usize;
let _ = final_mpi_v.to_teebigint(v, alloc_size);
}
let hdr = gcd as *mut BigintHdr;
let alloc_size = (*hdr).alloc_size as usize;
let _ = mpi_gcd.to_teebigint(gcd, alloc_size);
}
}
/// Returns a copy of `mpi` with its sign forced to +1, i.e. its absolute value.
fn mpi_abs_value(mpi: &Mpi) -> Mpi {
    let mut magnitude = mpi.clone();
    let handle: *mut mbedtls_sys::mpi = (&mut magnitude).into();
    // SAFETY: `handle` points at the mpi owned by `magnitude`, which is alive here.
    unsafe { (*handle).s = 1 };
    magnitude
}
/// Returns a copy of `mpi` with its sign field flipped.
fn negate_mpi_safe(mpi: &Mpi) -> Mpi {
    let mut negated = mpi.clone();
    let handle: *mut mbedtls_sys::mpi = (&mut negated).into();
    // SAFETY: `handle` points at the mpi owned by `negated`, which is alive here.
    unsafe { (*handle).s = -(*handle).s };
    negated
}
/// Binary extended Euclidean algorithm (HAC Algorithm 14.61) on non-negative inputs.
///
/// Returns `(g, a, b)` with `g = gcd(x, y)` and `a * x + b * y = g`. If either input is
/// zero, all three results are zero.
///
/// # Panics
/// Panics if an `Mpi` cannot be allocated or shifted.
fn extended_gcd_algorithm(x: &Mpi, y: &Mpi) -> (Mpi, Mpi, Mpi) {
    let new_mpi = |value| Mpi::new(value).expect("Failed to create Mpi");
    if is_mpi_zero(x) || is_mpi_zero(y) {
        return (new_mpi(0), new_mpi(0), new_mpi(0));
    }

    // Strip the common factor 2^k; it is restored on the gcd at the end.
    let mut xr = x.clone();
    let mut yr = y.clone();
    let mut k: usize = 0;
    while mpi_is_even(&xr) && mpi_is_even(&yr) {
        k += 1;
        xr = (&xr >> 1).expect("Shift operation failed");
        yr = (&yr >> 1).expect("Shift operation failed");
    }

    // Loop invariants: a * xr + b * yr == u and c * xr + d * yr == v.
    let mut u = xr.clone();
    let mut v = yr.clone();
    let (mut a, mut b) = (new_mpi(1), new_mpi(0));
    let (mut c, mut d) = (new_mpi(0), new_mpi(1));
    while !is_mpi_zero(&u) {
        while mpi_is_even(&u) {
            u = (&u >> 1).expect("Shift operation failed");
            // Before halving, make both coefficients even by adding (yr, -xr). This leaves
            // a * xr + b * yr unchanged. It must use the fixed reduced inputs, not the
            // evolving u/v, or the invariant breaks.
            if mpi_is_odd(&a) || mpi_is_odd(&b) {
                a = add_mpi_safe(&a, &yr);
                b = sub_mpi_safe(&b, &xr);
            }
            a = (&a >> 1).expect("Shift operation failed");
            b = (&b >> 1).expect("Shift operation failed");
        }
        while mpi_is_even(&v) {
            v = (&v >> 1).expect("Shift operation failed");
            if mpi_is_odd(&c) || mpi_is_odd(&d) {
                c = add_mpi_safe(&c, &yr);
                d = sub_mpi_safe(&d, &xr);
            }
            c = (&c >> 1).expect("Shift operation failed");
            d = (&d >> 1).expect("Shift operation failed");
        }
        if u.cmp(&v) == std::cmp::Ordering::Less {
            v = sub_mpi_safe(&v, &u);
            c = sub_mpi_safe(&c, &a);
            d = sub_mpi_safe(&d, &b);
        } else {
            u = sub_mpi_safe(&u, &v);
            a = sub_mpi_safe(&a, &c);
            b = sub_mpi_safe(&b, &d);
        }
    }
    // c * x + d * y = (c * xr + d * yr) * 2^k = v * 2^k.
    let gcd = (&v << k).expect("Shift operation failed");
    (gcd, c, d)
}
/// Returns `op1 + op2`, or zero if mbedtls reports an error.
///
/// # Panics
/// Panics if a fresh `Mpi` cannot be allocated.
fn add_mpi_safe(op1: &Mpi, op2: &Mpi) -> Mpi {
    let mut sum = Mpi::new(0).expect("Failed to create Mpi");
    // SAFETY: the handles refer to live Mpi values, and `sum` is distinct from both inputs.
    let status = unsafe { mbedtls_sys::mpi_add_mpi((&mut sum).into(), op1.into(), op2.into()) };
    match status {
        0 => sum,
        _ => Mpi::new(0).expect("Failed to create Mpi"),
    }
}
/// Returns `op1 - op2`, or zero if mbedtls reports an error.
///
/// # Panics
/// Panics if a fresh `Mpi` cannot be allocated.
fn sub_mpi_safe(op1: &Mpi, op2: &Mpi) -> Mpi {
    let mut difference = Mpi::new(0).expect("Failed to create Mpi");
    // SAFETY: the handles refer to live Mpi values, and `difference` is distinct from both
    // inputs.
    let status =
        unsafe { mbedtls_sys::mpi_sub_mpi((&mut difference).into(), op1.into(), op2.into()) };
    match status {
        0 => difference,
        _ => Mpi::new(0).expect("Failed to create Mpi"),
    }
}
/// Returns `true` if the magnitude of `mpi` has no non-zero byte. A value that cannot be
/// serialised is also treated as zero.
fn is_mpi_zero(mpi: &Mpi) -> bool {
    !mpi.to_binary()
        .is_ok_and(|bytes| bytes.iter().any(|&b| b != 0))
}
/// Returns `true` if the lowest bit of `mpi`'s magnitude is clear.
///
/// The magnitude is serialised big-endian, so the parity lives in the last byte. An empty
/// serialisation (zero) or a serialisation failure counts as even.
fn mpi_is_even(mpi: &Mpi) -> bool {
    !mpi.to_binary()
        .is_ok_and(|bytes| bytes.last().is_some_and(|&low| low & 1 == 1))
}
fn mpi_is_odd(mpi: &Mpi) -> bool {
!mpi_is_even(mpi)
}
/// Probabilistic primality test (GP `TEE_BigIntIsProbablePrime`).
///
/// Runs mbedtls Miller-Rabin with `max(confidenceLevel, 80)` rounds. Returns 1 if `op` is
/// probably prime, and 0 otherwise, including when `op` is null, cannot be decoded, or
/// mbedtls reports an error.
///
/// Witnesses come from a SplitMix64 stream seeded per call from `RandomState`. A fixed byte
/// pattern would make every round reuse the same witness. It can also make mbedtls' witness
/// search exhaust its retries, which reports a prime as composite.
///
/// NOTE(review): SplitMix64 is not a CSPRNG, and the quality of `RandomState`'s seed depends
/// on the platform's std. Wire in the TEE's RNG if this module gains access to one.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntIsProbablePrime(op: *const TEE_BigInt, confidenceLevel: u32) -> i32 {
    use core::cell::Cell;
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    type CVoid = mbedtls_sys_auto::types::raw_types::c_void;

    /// SplitMix64 generator handed to mbedtls as its RNG callback. The state sits in a
    /// `Cell` because mbedtls only receives a pointer derived from `&self`.
    struct WitnessRng {
        state: Cell<u64>,
    }

    impl WitnessRng {
        fn next_u64(&self) -> u64 {
            let s = self.state.get().wrapping_add(0x9E37_79B9_7F4A_7C15);
            self.state.set(s);
            let mut z = s;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }
    }

    impl RngCallback for WitnessRng {
        unsafe extern "C" fn call(user_data: *mut CVoid, data: *mut u8, len: usize) -> i32 {
            if len == 0 {
                return 0;
            }
            // SAFETY: `user_data` is the pointer returned by `data_ptr`. It refers to the
            // `WitnessRng` that stays alive for the whole `is_probably_prime` call, and mbedtls
            // passes a writable buffer of `len` bytes.
            let (rng, out) = unsafe {
                (
                    &*(user_data as *const WitnessRng),
                    core::slice::from_raw_parts_mut(data, len),
                )
            };
            for chunk in out.chunks_mut(8) {
                let bytes = rng.next_u64().to_le_bytes();
                chunk.copy_from_slice(&bytes[..chunk.len()]);
            }
            0
        }

        fn data_ptr(&self) -> *mut CVoid {
            self as *const Self as *mut CVoid
        }
    }

    if op.is_null() {
        return 0;
    }
    // SAFETY: `op` is non-null and must reference an initialised bigint.
    let mpi = match unsafe { Mpi::from_teebigint(op) } {
        Ok(mpi) => mpi,
        Err(_) => return 0,
    };
    // GP: confidence levels below 80 are raised to 80.
    let rounds = confidenceLevel.max(80);
    let seed = RandomState::new().build_hasher().finish();
    let mut rng = WitnessRng {
        state: Cell::new(seed),
    };
    match mpi.is_probably_prime(rounds, &mut rng) {
        Ok(()) => 1,
        Err(_) => 0,
    }
}
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntInitFMM(big_int_fmm: *mut TEE_BigIntFMM, len: usize) {
TEE_BigIntInit(big_int_fmm, len);
}
/// Initialises an FMM context (GP `TEE_BigIntInitFMMContext1`). Always returns
/// `TEE_SUCCESS`.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntInitFMMContext1(
    context: *mut TEE_BigIntFMMContext,
    len: usize,
    modulus: *const TEE_BigInt,
) -> TEE_Result {
    // No context state is kept: FMM values here are plain residues, and the FMM routines take
    // the modulus directly. Every argument is therefore ignored.
    let _ = (context, len, modulus);
    TEE_SUCCESS
}
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntFMMSizeInU32(modulus_size_in_bits: usize) -> usize {
tee_big_int_size_in_u32(modulus_size_in_bits)
}
/// Words needed for an FMM context (GP `TEE_BigIntFMMContextSizeInU32`).
///
/// The context is never read by this implementation, so one word suffices regardless of
/// the modulus width.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntFMMContextSizeInU32(modulus_size_in_bits: usize) -> usize {
    _ = modulus_size_in_bits;
    1
}
/// Converts `src` to FMM form (GP `TEE_BigIntConvertToFMM`). Here that is simply
/// `src mod n`; `context` is ignored.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntConvertToFMM(
    dest: *mut TEE_BigIntFMM,
    src: *const TEE_BigInt,
    n: *const TEE_BigInt,
    context: *const TEE_BigIntFMMContext,
) {
    let _ = context;
    let plain_dest = dest as *mut TEE_BigInt;
    TEE_BigIntMod(plain_dest, src, n);
}
/// Converts an FMM value back to a plain bigint (GP `TEE_BigIntConvertFromFMM`).
///
/// FMM values are plain residues here, so this copies `src` into `dest`. `n` and `context`
/// are ignored. Null pointers, decode failures and capacity failures are ignored.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntConvertFromFMM(
    dest: *mut TEE_BigInt,
    src: *const TEE_BigIntFMM,
    n: *const TEE_BigInt,
    context: *const TEE_BigIntFMMContext,
) {
    let _ = (n, context);
    if dest.is_null() || src.is_null() {
        return;
    }
    // SAFETY: both pointers are non-null and must reference initialised bigints.
    unsafe {
        if let Ok(value) = Mpi::from_teebigint(src as *const TEE_BigInt) {
            let capacity = (*(dest as *const BigintHdr)).alloc_size as usize;
            let _ = value.to_teebigint(dest, capacity);
        }
    }
}
/// Computes `dest = (op1 * op2) mod n` on FMM values (GP `TEE_BigIntComputeFMM`).
///
/// `context` is ignored. Null pointers, decode failures, mbedtls errors and results that do
/// not fit are all silently ignored.
#[unsafe(no_mangle)]
pub extern "C" fn TEE_BigIntComputeFMM(
    dest: *mut TEE_BigIntFMM,
    op1: *const TEE_BigIntFMM,
    op2: *const TEE_BigIntFMM,
    n: *const TEE_BigInt,
    context: *const TEE_BigIntFMMContext,
) {
    if dest.is_null() || op1.is_null() || op2.is_null() || n.is_null() {
        return;
    }
    let _ = context;
    // SAFETY: every pointer handed to `decode` is non-null and must reference an initialised
    // bigint.
    let decode = |p: *const TEE_BigInt| unsafe { Mpi::from_teebigint(p).ok() };
    let Some(lhs) = decode(op1 as *const TEE_BigInt) else {
        return;
    };
    let rhs = if op2 as *const TEE_BigInt == op1 as *const TEE_BigInt {
        lhs.clone()
    } else {
        match decode(op2 as *const TEE_BigInt) {
            Some(mpi) => mpi,
            None => return,
        }
    };
    let Some(modulus) = decode(n) else {
        return;
    };
    let Ok(mut product) = Mpi::new(0) else {
        return;
    };
    // SAFETY: the handles come from live Mpi values, and `dest` is non-null with a valid
    // header.
    unsafe {
        if mbedtls_sys::mpi_mul_mpi((&mut product).into(), (&lhs).into(), (&rhs).into()) != 0 {
            return;
        }
        let capacity = (*(dest as *const BigintHdr)).alloc_size as usize;
        // Reduce in place: the output and the first operand are the same mpi.
        if mbedtls_sys::mpi_mod_mpi((&mut product).into(), (&product).into(), (&modulus).into())
            != 0
        {
            return;
        }
        let _ = product.to_teebigint(dest as *mut TEE_BigInt, capacity);
    }
}